Enhance Neo4jChatMemoryRepository, expand integration tests, improve configuration, and update documentation
- Enhanced Neo4jChatMemoryRepository to correctly restore custom metadata for SystemMessage using SystemMessage.Builder. - Refactored and clarified Neo4jChatMemoryRepository implementation code. - Added comprehensive integration tests for Neo4jChatMemoryConfig and Neo4jChatMemoryRepository, including: -- Index creation verification -- Custom label support -- Getter validation for all configuration properties -- Tests for saving and retrieving SystemMessage metadata -- Tests ensuring saveAll(conversationId, Collections.emptyList()) clears all messages and removes the conversation node -- Tests for handling of messages with empty content and empty metadata -- Improved overall test coverage for Neo4j persistence and configuration edge cases - Fixed resource management bugs in test classes (ensured proper driver/session closure). - Improved index creation logic in Neo4jChatMemoryConfig for reliability and logging. - Updated documentation to include Neo4jChatMemoryRepository usage and configuration Signed-off-by: enricorampazzo <enrico.rampazzo@live.com> Signed-off-by: Mark Pollack <mark.pollack@broadcom.com>
This commit is contained in:
committed by
Mark Pollack
parent
4ec067641b
commit
43aa8939d2
@@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Copyright 2024-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.model.chat.memory.jdbc.autoconfigure;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.testcontainers.containers.PostgreSQLContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Jonathan Leijendekker
|
||||
*/
|
||||
@Testcontainers
|
||||
class JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests {
|
||||
|
||||
static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17");
|
||||
|
||||
@Container
|
||||
@SuppressWarnings("resource")
|
||||
static PostgreSQLContainer<?> postgresContainer = new PostgreSQLContainer<>(DEFAULT_IMAGE_NAME)
|
||||
.withDatabaseName("chat_memory_initializer_test")
|
||||
.withUsername("postgres")
|
||||
.withPassword("postgres");
|
||||
|
||||
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
|
||||
.withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class,
|
||||
JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class))
|
||||
.withPropertyValues(String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()),
|
||||
String.format("spring.datasource.username=%s", postgresContainer.getUsername()),
|
||||
String.format("spring.datasource.password=%s", postgresContainer.getPassword()));
|
||||
|
||||
@Test
|
||||
void getSettings_shouldHaveSchemaLocations() {
|
||||
this.contextRunner.run(context -> {
|
||||
var dataSource = context.getBean(DataSource.class);
|
||||
var settings = JdbcChatMemoryDataSourceScriptDatabaseInitializer.getSettings(dataSource);
|
||||
|
||||
assertThat(settings.getSchemaLocations())
|
||||
.containsOnly("classpath:org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql");
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,8 +18,8 @@ package org.springframework.ai.model.chat.memory.neo4j.autoconfigure;
|
||||
|
||||
import org.neo4j.driver.Driver;
|
||||
|
||||
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
|
||||
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
|
||||
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryRepository;
|
||||
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
@@ -29,19 +29,19 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
||||
/**
|
||||
* {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemory}.
|
||||
* {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemoryRepository}.
|
||||
*
|
||||
* @author Enrico Rampazzo
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@AutoConfiguration(after = Neo4jAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class)
|
||||
@ConditionalOnClass({ Neo4jChatMemory.class, Driver.class })
|
||||
@ConditionalOnClass({ Neo4jChatMemoryRepository.class, Driver.class })
|
||||
@EnableConfigurationProperties(Neo4jChatMemoryProperties.class)
|
||||
public class Neo4jChatMemoryAutoConfiguration {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
public Neo4jChatMemory chatMemory(Neo4jChatMemoryProperties properties, Driver driver) {
|
||||
public Neo4jChatMemoryRepository chatMemoryRepository(Neo4jChatMemoryProperties properties, Driver driver) {
|
||||
|
||||
var builder = Neo4jChatMemoryConfig.builder()
|
||||
.withMediaLabel(properties.getMediaLabel())
|
||||
@@ -52,7 +52,7 @@ public class Neo4jChatMemoryAutoConfiguration {
|
||||
.withToolResponseLabel(properties.getToolResponseLabel())
|
||||
.withDriver(driver);
|
||||
|
||||
return Neo4jChatMemory.create(builder.build());
|
||||
return new Neo4jChatMemoryRepository(builder.build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -23,12 +23,14 @@ import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryRepository;
|
||||
import org.testcontainers.containers.Neo4jContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
|
||||
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
@@ -51,7 +53,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@Testcontainers
|
||||
class Neo4jChatMemoryAutoConfigurationIT {
|
||||
class Neo4jChatMemoryRepositoryAutoConfigurationIT {
|
||||
|
||||
static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j");
|
||||
|
||||
@@ -67,31 +69,31 @@ class Neo4jChatMemoryAutoConfigurationIT {
|
||||
@Test
|
||||
void addAndGet() {
|
||||
this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl()).run(context -> {
|
||||
Neo4jChatMemory memory = context.getBean(Neo4jChatMemory.class);
|
||||
ChatMemoryRepository memory = context.getBean(ChatMemoryRepository.class);
|
||||
|
||||
String sessionId = UUID.randomUUID().toString();
|
||||
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
|
||||
assertThat(memory.findByConversationId(sessionId)).isEmpty();
|
||||
|
||||
UserMessage userMessage = new UserMessage("test question");
|
||||
|
||||
memory.add(sessionId, userMessage);
|
||||
List<Message> messages = memory.get(sessionId, Integer.MAX_VALUE);
|
||||
memory.saveAll(sessionId, List.of(userMessage));
|
||||
List<Message> messages = memory.findByConversationId(sessionId);
|
||||
assertThat(messages).hasSize(1);
|
||||
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(userMessage);
|
||||
|
||||
memory.clear(sessionId);
|
||||
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
|
||||
memory.deleteByConversationId(sessionId);
|
||||
assertThat(memory.findByConversationId(sessionId)).isEmpty();
|
||||
|
||||
AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(),
|
||||
List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments")));
|
||||
|
||||
memory.add(sessionId, List.of(userMessage, assistantMessage));
|
||||
messages = memory.get(sessionId, Integer.MAX_VALUE);
|
||||
memory.saveAll(sessionId, List.of(userMessage, assistantMessage));
|
||||
messages = memory.findByConversationId(sessionId);
|
||||
assertThat(messages).hasSize(2);
|
||||
assertThat(messages.get(1)).isEqualTo(userMessage);
|
||||
assertThat(messages.get(0)).isEqualTo(userMessage);
|
||||
|
||||
assertThat(messages.get(0)).isEqualTo(assistantMessage);
|
||||
memory.clear(sessionId);
|
||||
assertThat(messages.get(1)).isEqualTo(assistantMessage);
|
||||
memory.deleteByConversationId(sessionId);
|
||||
MimeType textPlain = MimeType.valueOf("text/plain");
|
||||
List<Media> media = List.of(
|
||||
Media.builder()
|
||||
@@ -102,28 +104,28 @@ class Neo4jChatMemoryAutoConfigurationIT {
|
||||
.build(),
|
||||
Media.builder().data(URI.create("http://www.google.com")).mimeType(textPlain).build());
|
||||
UserMessage userMessageWithMedia = UserMessage.builder().text("Message with media").media(media).build();
|
||||
memory.add(sessionId, userMessageWithMedia);
|
||||
memory.saveAll(sessionId, List.of(userMessageWithMedia));
|
||||
|
||||
messages = memory.get(sessionId, Integer.MAX_VALUE);
|
||||
messages = memory.findByConversationId(sessionId);
|
||||
assertThat(messages.size()).isEqualTo(1);
|
||||
assertThat(messages.get(0)).isEqualTo(userMessageWithMedia);
|
||||
assertThat(((UserMessage) messages.get(0)).getMedia()).hasSize(2);
|
||||
assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator()
|
||||
.isEqualTo(media);
|
||||
memory.clear(sessionId);
|
||||
memory.deleteByConversationId(sessionId);
|
||||
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(
|
||||
List.of(new ToolResponse("id", "name", "responseData"),
|
||||
new ToolResponse("id2", "name2", "responseData2")),
|
||||
Map.of("id", "id", "metadataKey", "metadata"));
|
||||
memory.add(sessionId, toolResponseMessage);
|
||||
messages = memory.get(sessionId, Integer.MAX_VALUE);
|
||||
memory.saveAll(sessionId, List.of(toolResponseMessage));
|
||||
messages = memory.findByConversationId(sessionId);
|
||||
assertThat(messages.size()).isEqualTo(1);
|
||||
assertThat(messages.get(0)).isEqualTo(toolResponseMessage);
|
||||
|
||||
memory.clear(sessionId);
|
||||
memory.deleteByConversationId(sessionId);
|
||||
SystemMessage sm = new SystemMessage("this is a System message");
|
||||
memory.add(sessionId, sm);
|
||||
messages = memory.get(sessionId, Integer.MAX_VALUE);
|
||||
memory.saveAll(sessionId, List.of(sm));
|
||||
messages = memory.findByConversationId(sessionId);
|
||||
assertThat(messages).hasSize(1);
|
||||
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(sm);
|
||||
});
|
||||
@@ -148,7 +150,7 @@ class Neo4jChatMemoryAutoConfigurationIT {
|
||||
propertyBase.formatted("toolresponselabel", toolResponseLabel),
|
||||
propertyBase.formatted("medialabel", mediaLabel))
|
||||
.run(context -> {
|
||||
Neo4jChatMemory chatMemory = context.getBean(Neo4jChatMemory.class);
|
||||
Neo4jChatMemoryRepository chatMemory = context.getBean(Neo4jChatMemoryRepository.class);
|
||||
Neo4jChatMemoryConfig config = chatMemory.getConfig();
|
||||
assertThat(config.getMessageLabel()).isEqualTo(messageLabel);
|
||||
assertThat(config.getMediaLabel()).isEqualTo(mediaLabel);
|
||||
@@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.model.embedding.observation.autoconfigure;
|
||||
|
||||
import io.micrometer.core.instrument.MeterRegistry;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler;
|
||||
import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
||||
/**
|
||||
* Auto-configuration for Spring AI embedding model observations.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@AutoConfiguration(
|
||||
afterName = "org.springframework.boot.actuate.autoconfigure.observation.ObservationAutoConfiguration")
|
||||
@ConditionalOnClass(EmbeddingModel.class)
|
||||
public class EmbeddingObservationAutoConfiguration {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
@ConditionalOnBean(MeterRegistry.class)
|
||||
EmbeddingModelMeterObservationHandler embeddingModelMeterObservationHandler(
|
||||
ObjectProvider<MeterRegistry> meterRegistry) {
|
||||
return new EmbeddingModelMeterObservationHandler(meterRegistry.getObject());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.model.embedding.observation.autoconfigure;
|
||||
|
||||
import io.micrometer.core.instrument.composite.CompositeMeterRegistry;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link EmbeddingObservationAutoConfiguration}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class EmbeddingObservationAutoConfigurationTests {
|
||||
|
||||
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
|
||||
.withConfiguration(AutoConfigurations.of(EmbeddingObservationAutoConfiguration.class));
|
||||
|
||||
@Test
|
||||
void meterObservationHandlerEnabled() {
|
||||
this.contextRunner.withBean(CompositeMeterRegistry.class)
|
||||
.run(context -> assertThat(context).hasSingleBean(EmbeddingModelMeterObservationHandler.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void meterObservationHandlerDisabled() {
|
||||
this.contextRunner
|
||||
.run(context -> assertThat(context).doesNotHaveBean(EmbeddingModelMeterObservationHandler.class));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -50,6 +50,7 @@
|
||||
<artifactId>spring-data-neo4j</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- TESTING -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
@@ -68,6 +69,12 @@
|
||||
<artifactId>spring-boot-testcontainers</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>testcontainers</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.neo4j.driver</groupId>
|
||||
@@ -79,6 +86,12 @@
|
||||
<artifactId>neo4j</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
|
||||
@@ -93,6 +93,33 @@ public final class Neo4jChatMemoryConfig {
|
||||
this.toolCallLabel = builder.toolCallLabel;
|
||||
this.metadataLabel = builder.metadataLabel;
|
||||
this.toolResponseLabel = builder.toolResponseLabel;
|
||||
ensureIndexes();
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that indexes exist on conversationId for Session nodes and index for
|
||||
* Message nodes. This improves query performance for lookups and ordering.
|
||||
*/
|
||||
private void ensureIndexes() {
|
||||
if (this.driver == null) {
|
||||
logger.warn("Neo4j Driver is null, cannot ensure indexes.");
|
||||
return;
|
||||
}
|
||||
try (var session = this.driver.session()) {
|
||||
// Index for conversationId on Session nodes
|
||||
String sessionIndexCypher = String.format(
|
||||
"CREATE INDEX session_conversation_id_index IF NOT EXISTS FOR (n:%s) ON (n.conversationId)",
|
||||
this.sessionLabel);
|
||||
// Index for index on Message nodes
|
||||
String messageIndexCypher = String
|
||||
.format("CREATE INDEX message_index_index IF NOT EXISTS FOR (n:%s) ON (n.index)", this.messageLabel);
|
||||
session.run(sessionIndexCypher);
|
||||
session.run(messageIndexCypher);
|
||||
logger.info("Ensured Neo4j indexes for conversationId and message index.");
|
||||
}
|
||||
catch (Exception e) {
|
||||
logger.warn("Failed to ensure Neo4j indexes for chat memory: {}", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
|
||||
@@ -0,0 +1,293 @@
|
||||
package org.springframework.ai.chat.memory.neo4j;
|
||||
|
||||
import org.neo4j.driver.Result;
|
||||
import org.neo4j.driver.Session;
|
||||
import org.neo4j.driver.Transaction;
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.messages.*;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.content.MediaContent;
|
||||
import org.springframework.util.MimeType;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* An implementation of {@link ChatMemoryRepository} for Neo4J
|
||||
*
|
||||
* @author Enrico Rampazzo
|
||||
* @since 1.0.0
|
||||
*/
|
||||
|
||||
public class Neo4jChatMemoryRepository implements ChatMemoryRepository {
|
||||
|
||||
private final Neo4jChatMemoryConfig config;
|
||||
|
||||
public Neo4jChatMemoryRepository(Neo4jChatMemoryConfig config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> findConversationIds() {
|
||||
try (var session = config.getDriver().session()) {
|
||||
return session.run("MATCH (conversation:%s) RETURN conversation.id".formatted(config.getSessionLabel()))
|
||||
.stream()
|
||||
.map(r -> r.get("conversation.id").asString())
|
||||
.toList();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> findByConversationId(String conversationId) {
|
||||
String statementBuilder = """
|
||||
MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s)
|
||||
WITH m
|
||||
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s)
|
||||
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) WITH m, metadata, media ORDER BY media.idx ASC
|
||||
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) WITH m, metadata, media, tr ORDER BY tr.idx ASC
|
||||
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
|
||||
WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC
|
||||
RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias
|
||||
ORDER BY m.idx ASC
|
||||
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
|
||||
this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(),
|
||||
this.config.getToolCallLabel());
|
||||
Result res = this.config.getDriver().session().run(statementBuilder, Map.of("conversationId", conversationId));
|
||||
return res.stream().map(record -> {
|
||||
Map<String, Object> messageMap = record.get("m").asMap();
|
||||
String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString();
|
||||
Message message = null;
|
||||
List<Media> mediaList = List.of();
|
||||
if (!record.get("medias").isNull()) {
|
||||
mediaList = getMedia(record);
|
||||
}
|
||||
if (msgType.equals(MessageType.USER.getValue())) {
|
||||
message = buildUserMessage(record, messageMap, mediaList);
|
||||
}
|
||||
if (msgType.equals(MessageType.ASSISTANT.getValue())) {
|
||||
message = buildAssistantMessage(record, messageMap, mediaList);
|
||||
}
|
||||
if (msgType.equals(MessageType.SYSTEM.getValue())) {
|
||||
SystemMessage.Builder systemMessageBuilder = SystemMessage.builder()
|
||||
.text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString());
|
||||
if (!record.get("metadata").isNull()) {
|
||||
Map<String, Object> retrievedMetadata = record.get("metadata").asMap();
|
||||
systemMessageBuilder.metadata(retrievedMetadata);
|
||||
}
|
||||
message = systemMessageBuilder.build();
|
||||
}
|
||||
if (msgType.equals(MessageType.TOOL.getValue())) {
|
||||
message = buildToolMessage(record);
|
||||
}
|
||||
if (message == null) {
|
||||
throw new IllegalArgumentException("%s messages are not supported"
|
||||
.formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString()));
|
||||
}
|
||||
message.getMetadata().put("messageType", message.getMessageType());
|
||||
return message;
|
||||
}).toList();
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveAll(String conversationId, List<Message> messages) {
|
||||
// First delete existing messages for this conversation
|
||||
deleteByConversationId(conversationId);
|
||||
|
||||
// Then add the new messages
|
||||
try (Session s = this.config.getDriver().session()) {
|
||||
try (Transaction t = s.beginTransaction()) {
|
||||
for (Message m : messages) {
|
||||
addMessageToTransaction(t, conversationId, m);
|
||||
}
|
||||
t.commit();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteByConversationId(String conversationId) {
|
||||
// First delete all messages and related nodes
|
||||
String deleteMessagesStatement = """
|
||||
MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s)
|
||||
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s)
|
||||
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s)
|
||||
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s)
|
||||
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
|
||||
DETACH DELETE m, metadata, media, tr, tc
|
||||
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
|
||||
this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(),
|
||||
this.config.getToolCallLabel());
|
||||
|
||||
// Then delete the conversation node itself
|
||||
String deleteConversationStatement = """
|
||||
MATCH (s:%s {id:$conversationId})
|
||||
DETACH DELETE s
|
||||
""".formatted(this.config.getSessionLabel());
|
||||
|
||||
try (Session s = this.config.getDriver().session()) {
|
||||
try (Transaction t = s.beginTransaction()) {
|
||||
// First delete messages
|
||||
t.run(deleteMessagesStatement, Map.of("conversationId", conversationId));
|
||||
// Then delete the conversation node
|
||||
t.run(deleteConversationStatement, Map.of("conversationId", conversationId));
|
||||
t.commit();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public Neo4jChatMemoryConfig getConfig() {
|
||||
return this.config;
|
||||
}
|
||||
|
||||
private Message buildToolMessage(org.neo4j.driver.Record record) {
|
||||
Message message;
|
||||
message = new ToolResponseMessage(record.get("toolResponses").asList(v -> {
|
||||
Map<String, Object> trMap = v.asMap();
|
||||
return new ToolResponseMessage.ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()),
|
||||
(String) trMap.get(ToolResponseAttributes.NAME.getValue()),
|
||||
(String) trMap.get(ToolResponseAttributes.RESPONSE_DATA.getValue()));
|
||||
}), record.get("metadata").asMap());
|
||||
return message;
|
||||
}
|
||||
|
||||
private Message buildAssistantMessage(org.neo4j.driver.Record record, Map<String, Object> messageMap,
|
||||
List<Media> mediaList) {
|
||||
Message message;
|
||||
message = new AssistantMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(),
|
||||
record.get("metadata").asMap(Map.of()), record.get("toolCalls").asList(v -> {
|
||||
var toolCallMap = v.asMap();
|
||||
return new AssistantMessage.ToolCall((String) toolCallMap.get("id"),
|
||||
(String) toolCallMap.get("type"), (String) toolCallMap.get("name"),
|
||||
(String) toolCallMap.get("arguments"));
|
||||
}), mediaList);
|
||||
return message;
|
||||
}
|
||||
|
||||
private Message buildUserMessage(org.neo4j.driver.Record record, Map<String, Object> messageMap,
|
||||
List<Media> mediaList) {
|
||||
Message message;
|
||||
Map<String, Object> metadata = record.get("metadata").asMap();
|
||||
message = UserMessage.builder()
|
||||
.text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString())
|
||||
.media(mediaList)
|
||||
.metadata(metadata)
|
||||
.build();
|
||||
return message;
|
||||
}
|
||||
|
||||
private List<Media> getMedia(org.neo4j.driver.Record record) {
|
||||
List<Media> mediaList;
|
||||
mediaList = record.get("medias").asList(v -> {
|
||||
Map<String, Object> mediaMap = v.asMap();
|
||||
var mediaBuilder = Media.builder()
|
||||
.name((String) mediaMap.get(MediaAttributes.NAME.getValue()))
|
||||
.id(Optional.ofNullable(mediaMap.get(MediaAttributes.ID.getValue())).map(Object::toString).orElse(null))
|
||||
.mimeType(MimeType.valueOf(mediaMap.get(MediaAttributes.MIME_TYPE.getValue()).toString()));
|
||||
if (mediaMap.get(MediaAttributes.DATA.getValue()) instanceof String stringData) {
|
||||
mediaBuilder.data(URI.create(stringData));
|
||||
}
|
||||
else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) {
|
||||
mediaBuilder.data(mediaMap.get(MediaAttributes.DATA.getValue()));
|
||||
}
|
||||
return mediaBuilder.build();
|
||||
|
||||
});
|
||||
return mediaList;
|
||||
}
|
||||
|
||||
private void addMessageToTransaction(Transaction t, String conversationId, Message message) {
|
||||
Map<String, Object> queryParameters = new HashMap<>();
|
||||
queryParameters.put("conversationId", conversationId);
|
||||
StringBuilder statementBuilder = new StringBuilder();
|
||||
statementBuilder.append("""
|
||||
MERGE (s:%s {id:$conversationId}) WITH s
|
||||
OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s
|
||||
CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties
|
||||
SET msg.idx = totalMsg + 1
|
||||
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
|
||||
this.config.getMessageLabel()));
|
||||
Map<String, Object> attributes = new HashMap<>();
|
||||
|
||||
attributes.put(MessageAttributes.MESSAGE_TYPE.getValue(), message.getMessageType().getValue());
|
||||
attributes.put(MessageAttributes.TEXT_CONTENT.getValue(), message.getText());
|
||||
attributes.put("id", UUID.randomUUID().toString());
|
||||
queryParameters.put("messageProperties", attributes);
|
||||
|
||||
if (!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) {
|
||||
statementBuilder.append("""
|
||||
WITH msg
|
||||
CREATE (metadataNode:%s)
|
||||
CREATE (msg)-[:HAS_METADATA]->(metadataNode)
|
||||
SET metadataNode = $metadata
|
||||
""".formatted(this.config.getMetadataLabel()));
|
||||
Map<String, Object> metadataCopy = new HashMap<>(message.getMetadata());
|
||||
metadataCopy.remove("messageType");
|
||||
queryParameters.put("metadata", metadataCopy);
|
||||
}
|
||||
if (message instanceof AssistantMessage assistantMessage) {
|
||||
if (assistantMessage.hasToolCalls()) {
|
||||
statementBuilder.append("""
|
||||
WITH msg
|
||||
FOREACH(tc in $toolCalls | CREATE (toolCall:%s) SET toolCall = tc
|
||||
CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall))
|
||||
""".formatted(this.config.getToolCallLabel()));
|
||||
List<Map<String, Object>> toolCallMaps = new ArrayList<>();
|
||||
for (int i = 0; i < assistantMessage.getToolCalls().size(); i++) {
|
||||
AssistantMessage.ToolCall tc = assistantMessage.getToolCalls().get(i);
|
||||
toolCallMaps
|
||||
.add(Map.of(ToolCallAttributes.ID.getValue(), tc.id(), ToolCallAttributes.NAME.getValue(),
|
||||
tc.name(), ToolCallAttributes.ARGUMENTS.getValue(), tc.arguments(),
|
||||
ToolCallAttributes.TYPE.getValue(), tc.type(), ToolCallAttributes.IDX.getValue(), i));
|
||||
}
|
||||
queryParameters.put("toolCalls", toolCallMaps);
|
||||
}
|
||||
}
|
||||
if (message instanceof ToolResponseMessage toolResponseMessage) {
|
||||
List<ToolResponseMessage.ToolResponse> toolResponses = toolResponseMessage.getResponses();
|
||||
List<Map<String, String>> toolResponseMaps = new ArrayList<>();
|
||||
for (int i = 0; i < Optional.ofNullable(toolResponses).orElse(List.of()).size(); i++) {
|
||||
var toolResponse = toolResponses.get(i);
|
||||
Map<String, String> toolResponseMap = Map.of(ToolResponseAttributes.ID.getValue(), toolResponse.id(),
|
||||
ToolResponseAttributes.NAME.getValue(), toolResponse.name(),
|
||||
ToolResponseAttributes.RESPONSE_DATA.getValue(), toolResponse.responseData(),
|
||||
ToolResponseAttributes.IDX.getValue(), Integer.toString(i));
|
||||
toolResponseMaps.add(toolResponseMap);
|
||||
}
|
||||
statementBuilder.append("""
|
||||
WITH msg
|
||||
FOREACH(tr IN $toolResponses | CREATE (tm:%s)
|
||||
SET tm = tr
|
||||
MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm))
|
||||
""".formatted(this.config.getToolResponseLabel()));
|
||||
queryParameters.put("toolResponses", toolResponseMaps);
|
||||
}
|
||||
if (message instanceof MediaContent messageWithMedia && !messageWithMedia.getMedia().isEmpty()) {
|
||||
List<Map<String, Object>> mediaNodes = convertMediaToMap(messageWithMedia.getMedia());
|
||||
statementBuilder.append("""
|
||||
WITH msg
|
||||
UNWIND $media AS m
|
||||
CREATE (media:%s) SET media = m
|
||||
WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media)
|
||||
""".formatted(this.config.getMediaLabel()));
|
||||
queryParameters.put("media", mediaNodes);
|
||||
}
|
||||
t.run(statementBuilder.toString(), queryParameters);
|
||||
}
|
||||
|
||||
private List<Map<String, Object>> convertMediaToMap(List<Media> media) {
|
||||
List<Map<String, Object>> mediaMaps = new ArrayList<>();
|
||||
for (int i = 0; i < media.size(); i++) {
|
||||
Map<String, Object> mediaMap = new HashMap<>();
|
||||
Media m = media.get(i);
|
||||
mediaMap.put(MediaAttributes.ID.getValue(), m.getId());
|
||||
mediaMap.put(MediaAttributes.MIME_TYPE.getValue(), m.getMimeType().toString());
|
||||
mediaMap.put(MediaAttributes.NAME.getValue(), m.getName());
|
||||
mediaMap.put(MediaAttributes.DATA.getValue(), m.getData());
|
||||
mediaMap.put(MediaAttributes.IDX.getValue(), i);
|
||||
mediaMaps.add(mediaMap);
|
||||
}
|
||||
return mediaMaps;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package org.springframework.ai.chat.memory.neo4j;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.neo4j.driver.Driver;
|
||||
import org.neo4j.driver.GraphDatabase;
|
||||
import org.neo4j.driver.Session;
|
||||
import org.neo4j.driver.Result;
|
||||
import org.testcontainers.containers.Neo4jContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Testcontainers
|
||||
class Neo4jChatMemoryConfigIT {
|
||||
|
||||
@Container
|
||||
static final Neo4jContainer<?> neo4jContainer = new Neo4jContainer<>("neo4j:5").withoutAuthentication();
|
||||
|
||||
static Driver driver;
|
||||
|
||||
@BeforeAll
|
||||
static void setupDriver() {
|
||||
driver = GraphDatabase.driver(neo4jContainer.getBoltUrl());
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void closeDriver() {
|
||||
if (driver != null)
|
||||
driver.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldCreateRequiredIndexes() {
|
||||
// Given
|
||||
Neo4jChatMemoryConfig config = Neo4jChatMemoryConfig.builder().withDriver(driver).build();
|
||||
// When
|
||||
try (Session session = driver.session()) {
|
||||
Result result = session.run("SHOW INDEXES");
|
||||
boolean sessionIndexFound = false;
|
||||
boolean messageIndexFound = false;
|
||||
while (result.hasNext()) {
|
||||
var record = result.next();
|
||||
String name = record.get("name").asString();
|
||||
if ("session_conversation_id_index".equals(name))
|
||||
sessionIndexFound = true;
|
||||
if ("message_index_index".equals(name))
|
||||
messageIndexFound = true;
|
||||
}
|
||||
// Then
|
||||
assertThat(sessionIndexFound).isTrue();
|
||||
assertThat(messageIndexFound).isTrue();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void builderShouldSetCustomLabels() {
|
||||
String customSessionLabel = "ChatSession";
|
||||
String customMessageLabel = "ChatMessage";
|
||||
Neo4jChatMemoryConfig config = Neo4jChatMemoryConfig.builder()
|
||||
.withDriver(driver)
|
||||
.withSessionLabel(customSessionLabel)
|
||||
.withMessageLabel(customMessageLabel)
|
||||
.build();
|
||||
assertThat(config.getSessionLabel()).isEqualTo(customSessionLabel);
|
||||
assertThat(config.getMessageLabel()).isEqualTo(customMessageLabel);
|
||||
}
|
||||
|
||||
@Test
|
||||
void gettersShouldReturnConfiguredValues() {
|
||||
Neo4jChatMemoryConfig config = Neo4jChatMemoryConfig.builder()
|
||||
.withDriver(driver)
|
||||
.withSessionLabel("Session")
|
||||
.withToolCallLabel("ToolCall")
|
||||
.withMetadataLabel("Metadata")
|
||||
.withMessageLabel("Message")
|
||||
.withToolResponseLabel("ToolResponse")
|
||||
.withMediaLabel("Media")
|
||||
.build();
|
||||
assertThat(config.getSessionLabel()).isEqualTo("Session");
|
||||
assertThat(config.getToolCallLabel()).isEqualTo("ToolCall");
|
||||
assertThat(config.getMetadataLabel()).isEqualTo("Metadata");
|
||||
assertThat(config.getMessageLabel()).isEqualTo("Message");
|
||||
assertThat(config.getToolResponseLabel()).isEqualTo("ToolResponse");
|
||||
assertThat(config.getMediaLabel()).isEqualTo("Media");
|
||||
assertThat(config.getDriver()).isNotNull();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,420 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.memory.neo4j;
|
||||
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
import org.neo4j.driver.Driver;
|
||||
import org.neo4j.driver.Result;
|
||||
import org.neo4j.driver.Session;
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.util.MimeType;
|
||||
import org.testcontainers.containers.Neo4jContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* Integration tests for {@link Neo4jChatMemoryRepository}.
|
||||
*
|
||||
* @author Enrico Rampazzo
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@Testcontainers
|
||||
class Neo4jChatMemoryRepositoryIT {
|
||||
|
||||
static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j");
|
||||
|
||||
@SuppressWarnings({ "rawtypes", "resource" })
|
||||
@Container
|
||||
static Neo4jContainer neo4jContainer = (Neo4jContainer) new Neo4jContainer(DEFAULT_IMAGE_NAME.withTag("5"))
|
||||
.withoutAuthentication()
|
||||
.withExposedPorts(7474, 7687);
|
||||
|
||||
private ChatMemoryRepository chatMemoryRepository;
|
||||
|
||||
private Driver driver;
|
||||
|
||||
private Neo4jChatMemoryConfig config;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
driver = Neo4jDriverFactory.create(neo4jContainer.getBoltUrl());
|
||||
config = Neo4jChatMemoryConfig.builder().withDriver(driver).build();
|
||||
chatMemoryRepository = new Neo4jChatMemoryRepository(config);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
// Clean up all data after each test
|
||||
try (Session session = driver.session()) {
|
||||
session.run("MATCH (n) DETACH DELETE n");
|
||||
}
|
||||
driver.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
void correctChatMemoryRepositoryInstance() {
|
||||
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
|
||||
assertThat(chatMemoryRepository).isInstanceOf(Neo4jChatMemoryRepository.class);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM",
|
||||
"Message from tool,TOOL" })
|
||||
void saveAndFindSingleMessage(String content, MessageType messageType) {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
Message message = createMessageByType(content + " - " + conversationId, messageType);
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.<Message>of(message));
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
|
||||
Message retrievedMessage = retrievedMessages.get(0);
|
||||
assertThat(retrievedMessage.getMessageType()).isEqualTo(messageType);
|
||||
|
||||
if (messageType != MessageType.TOOL) {
|
||||
assertThat(retrievedMessage.getText()).isEqualTo(message.getText());
|
||||
}
|
||||
|
||||
// Verify directly in the database
|
||||
try (Session session = driver.session()) {
|
||||
var result = session.run(
|
||||
"MATCH (s:%s {id:$conversationId})-[:HAS_MESSAGE]->(m:%s) RETURN count(m) as count"
|
||||
.formatted(config.getSessionLabel(), config.getMessageLabel()),
|
||||
Map.of("conversationId", conversationId));
|
||||
assertThat(result.single().get("count").asLong()).isEqualTo(1);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveAndFindMultipleMessages() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
List<Message> messages = List.of(new AssistantMessage("Message from assistant - " + conversationId),
|
||||
new UserMessage("Message from user - " + conversationId),
|
||||
new SystemMessage("Message from system - " + conversationId),
|
||||
new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData"))));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, messages);
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
|
||||
assertThat(retrievedMessages).hasSize(messages.size());
|
||||
|
||||
// Verify the order is preserved (ascending by index)
|
||||
for (int i = 0; i < messages.size(); i++) {
|
||||
if (messages.get(i).getMessageType() != MessageType.TOOL) {
|
||||
assertThat(retrievedMessages.get(i).getText()).isEqualTo(messages.get(i).getText());
|
||||
}
|
||||
assertThat(retrievedMessages.get(i).getMessageType()).isEqualTo(messages.get(i).getMessageType());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void verifyMessageOrdering() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
List<Message> messages = new ArrayList<>();
|
||||
|
||||
// Add messages in a specific order
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
messages.add(new UserMessage("Message " + i));
|
||||
}
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, messages);
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
|
||||
assertThat(retrievedMessages).hasSize(messages.size());
|
||||
|
||||
// Verify that messages are returned in ascending order (oldest first)
|
||||
for (int i = 0; i < messages.size(); i++) {
|
||||
assertThat(retrievedMessages.get(i).getText()).isEqualTo("Message " + (i + 1));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void findConversationIds() {
|
||||
// Create multiple conversations
|
||||
var conversationId1 = UUID.randomUUID().toString();
|
||||
var conversationId2 = UUID.randomUUID().toString();
|
||||
var conversationId3 = UUID.randomUUID().toString();
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId1, List.<Message>of(new UserMessage("Message for conversation 1")));
|
||||
chatMemoryRepository.saveAll(conversationId2, List.<Message>of(new UserMessage("Message for conversation 2")));
|
||||
chatMemoryRepository.saveAll(conversationId3, List.<Message>of(new UserMessage("Message for conversation 3")));
|
||||
|
||||
List<String> conversationIds = chatMemoryRepository.findConversationIds();
|
||||
|
||||
assertThat(conversationIds).hasSize(3);
|
||||
assertThat(conversationIds).contains(conversationId1, conversationId2, conversationId3);
|
||||
}
|
||||
|
||||
@Test
|
||||
void deleteByConversationId() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
List<Message> messages = List.of(new AssistantMessage("Message from assistant"),
|
||||
new UserMessage("Message from user"), new SystemMessage("Message from system"));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, messages);
|
||||
|
||||
// Verify messages were saved
|
||||
assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
|
||||
|
||||
// Delete the conversation
|
||||
chatMemoryRepository.deleteByConversationId(conversationId);
|
||||
|
||||
// Verify messages were deleted
|
||||
assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEmpty();
|
||||
|
||||
// Verify directly in the database
|
||||
try (Session session = driver.session()) {
|
||||
var result = session.run(
|
||||
"MATCH (s:%s {id:$conversationId}) RETURN count(s) as count".formatted(config.getSessionLabel()),
|
||||
Map.of("conversationId", conversationId));
|
||||
assertThat(result.single().get("count").asLong()).isZero();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveAllReplacesExistingMessages() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
|
||||
// Save initial messages
|
||||
List<Message> initialMessages = List.of(new UserMessage("Initial message 1"),
|
||||
new UserMessage("Initial message 2"), new UserMessage("Initial message 3"));
|
||||
chatMemoryRepository.saveAll(conversationId, initialMessages);
|
||||
|
||||
// Verify initial messages were saved
|
||||
assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
|
||||
|
||||
// Replace with new messages
|
||||
List<Message> newMessages = List.of(new UserMessage("New message 1"), new UserMessage("New message 2"));
|
||||
chatMemoryRepository.saveAll(conversationId, newMessages);
|
||||
|
||||
// Verify only new messages exist
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(2);
|
||||
assertThat(retrievedMessages.get(0).getText()).isEqualTo("New message 1");
|
||||
assertThat(retrievedMessages.get(1).getText()).isEqualTo("New message 2");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleMediaContent() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
|
||||
MimeType textPlain = MimeType.valueOf("text/plain");
|
||||
List<Media> media = List.of(Media.builder()
|
||||
.name("some media")
|
||||
.id(UUID.randomUUID().toString())
|
||||
.mimeType(textPlain)
|
||||
.data("hello".getBytes(StandardCharsets.UTF_8))
|
||||
.build(), Media.builder().data(URI.create("http://www.example.com")).mimeType(textPlain).build());
|
||||
|
||||
UserMessage userMessageWithMedia = UserMessage.builder().text("Message with media").media(media).build();
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.<Message>of(userMessageWithMedia));
|
||||
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
|
||||
UserMessage retrievedMessage = (UserMessage) retrievedMessages.get(0);
|
||||
assertThat(retrievedMessage.getMedia()).hasSize(2);
|
||||
assertThat(retrievedMessage.getMedia()).usingRecursiveFieldByFieldElementComparator().isEqualTo(media);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleAssistantMessageWithToolCalls() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
|
||||
AssistantMessage assistantMessage = new AssistantMessage("Message with tool calls", Map.of(),
|
||||
List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"),
|
||||
new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2")));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.<Message>of(assistantMessage));
|
||||
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
|
||||
AssistantMessage retrievedMessage = (AssistantMessage) retrievedMessages.get(0);
|
||||
assertThat(retrievedMessage.getToolCalls()).hasSize(2);
|
||||
assertThat(retrievedMessage.getToolCalls().get(0).id()).isEqualTo("id1");
|
||||
assertThat(retrievedMessage.getToolCalls().get(1).id()).isEqualTo("id2");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleToolResponseMessage() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
|
||||
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List
|
||||
.of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")),
|
||||
Map.of("metadataKey", "metadataValue"));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.<Message>of(toolResponseMessage));
|
||||
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
|
||||
ToolResponseMessage retrievedMessage = (ToolResponseMessage) retrievedMessages.get(0);
|
||||
assertThat(retrievedMessage.getResponses()).hasSize(2);
|
||||
assertThat(retrievedMessage.getResponses().get(0).id()).isEqualTo("id1");
|
||||
assertThat(retrievedMessage.getResponses().get(1).id()).isEqualTo("id2");
|
||||
assertThat(retrievedMessage.getMetadata()).containsEntry("metadataKey", "metadataValue");
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveAndFindSystemMessageWithMetadata() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
Map<String, Object> customMetadata = Map.of("priority", "high", "source", "test");
|
||||
|
||||
SystemMessage systemMessage = SystemMessage.builder()
|
||||
.text("System message with custom metadata - " + conversationId)
|
||||
.metadata(customMetadata)
|
||||
.build();
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.of(systemMessage));
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
Message retrievedMessage = retrievedMessages.get(0);
|
||||
|
||||
assertThat(retrievedMessage).isInstanceOf(SystemMessage.class);
|
||||
assertThat(retrievedMessage.getText()).isEqualTo("System message with custom metadata - " + conversationId);
|
||||
// Crucial assertion for the metadata
|
||||
assertThat(retrievedMessage.getMetadata()).containsAllEntriesOf(customMetadata);
|
||||
// Also check that the 'messageType' key is present (added by the repository)
|
||||
assertThat(retrievedMessage.getMetadata()).containsEntry("messageType", MessageType.SYSTEM);
|
||||
// Verify no extra unwanted metadata keys beyond what's expected
|
||||
assertThat(retrievedMessage.getMetadata().keySet())
|
||||
.containsExactlyInAnyOrderElementsOf(new ArrayList<>(customMetadata.keySet()) {
|
||||
{
|
||||
add("messageType");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveAllWithEmptyListClearsConversation() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
|
||||
// 1. Setup: Create a conversation with some initial messages
|
||||
UserMessage initialMessage1 = new UserMessage("Initial message 1");
|
||||
AssistantMessage initialMessage2 = new AssistantMessage("Initial response 1");
|
||||
chatMemoryRepository.saveAll(conversationId, List.of(initialMessage1, initialMessage2));
|
||||
|
||||
// Verify initial messages are there
|
||||
List<Message> messagesAfterInitialSave = chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(messagesAfterInitialSave).hasSize(2);
|
||||
|
||||
// 2. Action: Call saveAll with an empty list
|
||||
chatMemoryRepository.saveAll(conversationId, Collections.emptyList());
|
||||
|
||||
// 3. Assertions:
|
||||
// a) No messages should be found for the conversationId
|
||||
List<Message> messagesAfterEmptySave = chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(messagesAfterEmptySave).isEmpty();
|
||||
|
||||
// b) The conversationId itself should no longer be listed (because
|
||||
// deleteByConversationId removes the session node)
|
||||
List<String> conversationIds = chatMemoryRepository.findConversationIds();
|
||||
assertThat(conversationIds).doesNotContain(conversationId);
|
||||
|
||||
// c) Verify directly in Neo4j that the conversation node is gone
|
||||
try (Session session = driver.session()) {
|
||||
Result result = session.run(
|
||||
"MATCH (s:%s {id: $conversationId}) RETURN s".formatted(config.getSessionLabel()),
|
||||
Map.of("conversationId", conversationId));
|
||||
assertThat(result.hasNext()).isFalse(); // No conversation node should exist
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void saveAndFindMessagesWithEmptyContentOrMetadata() {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
|
||||
UserMessage messageWithEmptyContent = new UserMessage("");
|
||||
UserMessage messageWithEmptyMetadata = UserMessage.builder()
|
||||
.text("Content with empty metadata")
|
||||
.metadata(Collections.emptyMap())
|
||||
.build();
|
||||
|
||||
List<Message> messagesToSave = List.of(messageWithEmptyContent, messageWithEmptyMetadata);
|
||||
chatMemoryRepository.saveAll(conversationId, messagesToSave);
|
||||
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(2);
|
||||
|
||||
// Verify first message (empty content)
|
||||
Message retrievedEmptyContentMsg = retrievedMessages.get(0);
|
||||
assertThat(retrievedEmptyContentMsg).isInstanceOf(UserMessage.class);
|
||||
assertThat(retrievedEmptyContentMsg.getText()).isEqualTo("");
|
||||
assertThat(retrievedEmptyContentMsg.getMetadata()).containsEntry("messageType", MessageType.USER); // Default
|
||||
// metadata
|
||||
assertThat(retrievedEmptyContentMsg.getMetadata().keySet()).hasSize(1); // Only
|
||||
// messageType
|
||||
|
||||
// Verify second message (empty metadata from input, should only have messageType
|
||||
// after retrieval)
|
||||
Message retrievedEmptyMetadataMsg = retrievedMessages.get(1);
|
||||
assertThat(retrievedEmptyMetadataMsg).isInstanceOf(UserMessage.class);
|
||||
assertThat(retrievedEmptyMetadataMsg.getText()).isEqualTo("Content with empty metadata");
|
||||
assertThat(retrievedEmptyMetadataMsg.getMetadata()).containsEntry("messageType", MessageType.USER);
|
||||
assertThat(retrievedEmptyMetadataMsg.getMetadata().keySet()).hasSize(1); // Only
|
||||
// messageType
|
||||
}
|
||||
|
||||
private Message createMessageByType(String content, MessageType messageType) {
|
||||
return switch (messageType) {
|
||||
case ASSISTANT -> new AssistantMessage(content);
|
||||
case USER -> new UserMessage(content);
|
||||
case SYSTEM -> new SystemMessage(content);
|
||||
case TOOL -> new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData")));
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory for creating Neo4j Driver instances.
|
||||
*/
|
||||
private static class Neo4jDriverFactory {
|
||||
|
||||
static Driver create(String boltUrl) {
|
||||
return org.neo4j.driver.GraphDatabase.driver(boltUrl);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -64,7 +64,7 @@ If you'd rather create the `InMemoryChatMemoryRepository` manually, you can do s
|
||||
ChatMemoryRepository repository = new InMemoryChatMemoryRepository();
|
||||
----
|
||||
|
||||
=== JDBC Repository
|
||||
=== JdbcChatMemoryRepository
|
||||
|
||||
`JdbcChatMemoryRepository` is a built-in implementation that uses JDBC to store messages in a relational database. It is suitable for applications that require persistent storage of chat memory.
|
||||
|
||||
@@ -127,6 +127,7 @@ ChatMemory chatMemory = MessageWindowChatMemory.builder()
|
||||
| `spring.ai.chat.memory.repository.jdbc.initialize-schema` | Whether to initialize the schema on startup. | `true`
|
||||
|===
|
||||
|
||||
|
||||
==== Schema Initialization
|
||||
|
||||
The auto-configuration will automatically create the `ai_chat_memory` table using the JDBC driver. Currently, only PostgreSQL and MariaDB are supported.
|
||||
@@ -135,6 +136,78 @@ You can disable the schema initialization by setting the property `spring.ai.cha
|
||||
|
||||
If your project uses a tool like Flyway or Liquibase to manage your database schemas, you can disable the schema initialization and refer to link:https://github.com/spring-projects/spring-ai/tree/main/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc[these SQL scripts] for configuring those tools to create the `ai_chat_memory` table.
|
||||
|
||||
=== Neo4j ChatMemoryRepository
|
||||
|
||||
`Neo4jChatMemoryRepository` is a built-in implementation that uses Neo4j to store chat messages as nodes and relationships in a property graph database. It is suitable for applications that want to leverage Neo4j's graph capabilities for chat memory persistence.
|
||||
|
||||
First, add the following dependency to your project:
|
||||
|
||||
[tabs]
|
||||
======
|
||||
Maven::
|
||||
+
|
||||
[source, xml]
|
||||
----
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-starter-model-chat-memory-neo4j</artifactId>
|
||||
</dependency>
|
||||
----
|
||||
|
||||
Gradle::
|
||||
+
|
||||
[source,groovy]
|
||||
----
|
||||
dependencies {
|
||||
implementation 'org.springframework.ai:spring-ai-starter-model-chat-memory-neo4j'
|
||||
}
|
||||
----
|
||||
======
|
||||
|
||||
Spring AI provides auto-configuration for the `Neo4jChatMemoryRepository`, which you can use directly in your application.
|
||||
|
||||
[source,java]
|
||||
----
|
||||
@Autowired
|
||||
Neo4jChatMemoryRepository chatMemoryRepository;
|
||||
|
||||
ChatMemory chatMemory = MessageWindowChatMemory.builder()
|
||||
.chatMemoryRepository(chatMemoryRepository)
|
||||
.maxMessages(10)
|
||||
.build();
|
||||
----
|
||||
|
||||
If you'd rather create the `Neo4jChatMemoryRepository` manually, you can do so by providing a Neo4j `Driver` instance:
|
||||
|
||||
[source,java]
|
||||
----
|
||||
ChatMemoryRepository chatMemoryRepository = Neo4jChatMemoryRepository.builder()
|
||||
.driver(driver)
|
||||
.build();
|
||||
|
||||
ChatMemory chatMemory = MessageWindowChatMemory.builder()
|
||||
.chatMemoryRepository(chatMemoryRepository)
|
||||
.maxMessages(10)
|
||||
.build();
|
||||
----
|
||||
|
||||
==== Configuration Properties
|
||||
|
||||
[cols="2,5,1",stripes=even]
|
||||
|===
|
||||
|Property | Description | Default Value
|
||||
| `spring.ai.chat.memory.neo4j.sessionLabel` | The label for the nodes that store conversation sessions | `Session`
|
||||
| `spring.ai.chat.memory.neo4j.messageLabel` | The label for the nodes that store messages | `Message`
|
||||
| `spring.ai.chat.memory.neo4j.toolCallLabel` | The label for nodes that store tool calls (e.g. in Assistant Messages) | `ToolCall`
|
||||
| `spring.ai.chat.memory.neo4j.metadataLabel` | The label for nodes that store message metadata | `Metadata`
|
||||
| `spring.ai.chat.memory.neo4j.toolResponseLabel` | The label for the nodes that store tool responses | `ToolResponse`
|
||||
| `spring.ai.chat.memory.neo4j.mediaLabel` | The label for the nodes that store media associated with a message | `Media`
|
||||
|===
|
||||
|
||||
==== Index Initialization
|
||||
|
||||
The Neo4j repository will automatically ensure that indexes are created for conversation IDs and message indices to optimize performance. If you use custom labels, indexes will be created for those labels as well. No schema initialization is required, but you should ensure your Neo4j instance is accessible to your application.
|
||||
|
||||
== Memory in Chat Client
|
||||
|
||||
When using the ChatClient API, you can provide a `ChatMemory` implementation to maintain conversation context across multiple interactions.
|
||||
|
||||
Reference in New Issue
Block a user