Enhance Neo4jChatMemoryRepository, expand integration tests, improve configuration, and update documentation

- Enhanced Neo4jChatMemoryRepository to correctly restore custom metadata for SystemMessage using SystemMessage.Builder.
- Refactored and clarified Neo4jChatMemoryRepository implementation code.
- Added comprehensive integration tests for Neo4jChatMemoryConfig and Neo4jChatMemoryRepository, including:
-- Index creation verification
-- Custom label support
-- Getter validation for all configuration properties
-- Tests for saving and retrieving SystemMessage metadata
-- Tests ensuring saveAll(conversationId, Collections.emptyList()) clears all messages and removes the conversation node
-- Tests for handling of messages with empty content and empty metadata
-- Improved overall test coverage for Neo4j persistence and configuration edge cases
- Fixed resource management bugs in test classes (ensured proper driver/session closure).
- Improved index creation logic in Neo4jChatMemoryConfig for reliability and logging.
- Updated documentation to include Neo4jChatMemoryRepository usage and configuration

Signed-off-by: enricorampazzo <enrico.rampazzo@live.com>
Signed-off-by: Mark Pollack <mark.pollack@broadcom.com>
This commit is contained in:
enricorampazzo
2025-05-06 15:45:53 +04:00
committed by Mark Pollack
parent 4ec067641b
commit 43aa8939d2
11 changed files with 1113 additions and 28 deletions

View File

@@ -0,0 +1,67 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.model.chat.memory.jdbc.autoconfigure;
import javax.sql.DataSource;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Jonathan Leijendekker
*/
@Testcontainers
class JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests {
static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17");
@Container
@SuppressWarnings("resource")
static PostgreSQLContainer<?> postgresContainer = new PostgreSQLContainer<>(DEFAULT_IMAGE_NAME)
.withDatabaseName("chat_memory_initializer_test")
.withUsername("postgres")
.withPassword("postgres");
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class,
JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class))
.withPropertyValues(String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()),
String.format("spring.datasource.username=%s", postgresContainer.getUsername()),
String.format("spring.datasource.password=%s", postgresContainer.getPassword()));
@Test
void getSettings_shouldHaveSchemaLocations() {
this.contextRunner.run(context -> {
var dataSource = context.getBean(DataSource.class);
var settings = JdbcChatMemoryDataSourceScriptDatabaseInitializer.getSettings(dataSource);
assertThat(settings.getSchemaLocations())
.containsOnly("classpath:org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql");
});
}
}

View File

@@ -18,8 +18,8 @@ package org.springframework.ai.model.chat.memory.neo4j.autoconfigure;
import org.neo4j.driver.Driver;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryRepository;
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
@@ -29,19 +29,19 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean;
/**
* {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemory}.
* {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemoryRepository}.
*
* @author Enrico Rampazzo
* @since 1.0.0
*/
@AutoConfiguration(after = Neo4jAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class)
@ConditionalOnClass({ Neo4jChatMemory.class, Driver.class })
@ConditionalOnClass({ Neo4jChatMemoryRepository.class, Driver.class })
@EnableConfigurationProperties(Neo4jChatMemoryProperties.class)
public class Neo4jChatMemoryAutoConfiguration {
@Bean
@ConditionalOnMissingBean
public Neo4jChatMemory chatMemory(Neo4jChatMemoryProperties properties, Driver driver) {
public Neo4jChatMemoryRepository chatMemoryRepository(Neo4jChatMemoryProperties properties, Driver driver) {
var builder = Neo4jChatMemoryConfig.builder()
.withMediaLabel(properties.getMediaLabel())
@@ -52,7 +52,7 @@ public class Neo4jChatMemoryAutoConfiguration {
.withToolResponseLabel(properties.getToolResponseLabel())
.withDriver(driver);
return Neo4jChatMemory.create(builder.build());
return new Neo4jChatMemoryRepository(builder.build());
}
}

View File

@@ -23,12 +23,14 @@ import java.util.Map;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryRepository;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
@@ -51,7 +53,7 @@ import static org.assertj.core.api.Assertions.assertThat;
* @since 1.0.0
*/
@Testcontainers
class Neo4jChatMemoryAutoConfigurationIT {
class Neo4jChatMemoryRepositoryAutoConfigurationIT {
static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j");
@@ -67,31 +69,31 @@ class Neo4jChatMemoryAutoConfigurationIT {
@Test
void addAndGet() {
this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl()).run(context -> {
Neo4jChatMemory memory = context.getBean(Neo4jChatMemory.class);
ChatMemoryRepository memory = context.getBean(ChatMemoryRepository.class);
String sessionId = UUID.randomUUID().toString();
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
assertThat(memory.findByConversationId(sessionId)).isEmpty();
UserMessage userMessage = new UserMessage("test question");
memory.add(sessionId, userMessage);
List<Message> messages = memory.get(sessionId, Integer.MAX_VALUE);
memory.saveAll(sessionId, List.of(userMessage));
List<Message> messages = memory.findByConversationId(sessionId);
assertThat(messages).hasSize(1);
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(userMessage);
memory.clear(sessionId);
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
memory.deleteByConversationId(sessionId);
assertThat(memory.findByConversationId(sessionId)).isEmpty();
AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(),
List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments")));
memory.add(sessionId, List.of(userMessage, assistantMessage));
messages = memory.get(sessionId, Integer.MAX_VALUE);
memory.saveAll(sessionId, List.of(userMessage, assistantMessage));
messages = memory.findByConversationId(sessionId);
assertThat(messages).hasSize(2);
assertThat(messages.get(1)).isEqualTo(userMessage);
assertThat(messages.get(0)).isEqualTo(userMessage);
assertThat(messages.get(0)).isEqualTo(assistantMessage);
memory.clear(sessionId);
assertThat(messages.get(1)).isEqualTo(assistantMessage);
memory.deleteByConversationId(sessionId);
MimeType textPlain = MimeType.valueOf("text/plain");
List<Media> media = List.of(
Media.builder()
@@ -102,28 +104,28 @@ class Neo4jChatMemoryAutoConfigurationIT {
.build(),
Media.builder().data(URI.create("http://www.google.com")).mimeType(textPlain).build());
UserMessage userMessageWithMedia = UserMessage.builder().text("Message with media").media(media).build();
memory.add(sessionId, userMessageWithMedia);
memory.saveAll(sessionId, List.of(userMessageWithMedia));
messages = memory.get(sessionId, Integer.MAX_VALUE);
messages = memory.findByConversationId(sessionId);
assertThat(messages.size()).isEqualTo(1);
assertThat(messages.get(0)).isEqualTo(userMessageWithMedia);
assertThat(((UserMessage) messages.get(0)).getMedia()).hasSize(2);
assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator()
.isEqualTo(media);
memory.clear(sessionId);
memory.deleteByConversationId(sessionId);
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(
List.of(new ToolResponse("id", "name", "responseData"),
new ToolResponse("id2", "name2", "responseData2")),
Map.of("id", "id", "metadataKey", "metadata"));
memory.add(sessionId, toolResponseMessage);
messages = memory.get(sessionId, Integer.MAX_VALUE);
memory.saveAll(sessionId, List.of(toolResponseMessage));
messages = memory.findByConversationId(sessionId);
assertThat(messages.size()).isEqualTo(1);
assertThat(messages.get(0)).isEqualTo(toolResponseMessage);
memory.clear(sessionId);
memory.deleteByConversationId(sessionId);
SystemMessage sm = new SystemMessage("this is a System message");
memory.add(sessionId, sm);
messages = memory.get(sessionId, Integer.MAX_VALUE);
memory.saveAll(sessionId, List.of(sm));
messages = memory.findByConversationId(sessionId);
assertThat(messages).hasSize(1);
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(sm);
});
@@ -148,7 +150,7 @@ class Neo4jChatMemoryAutoConfigurationIT {
propertyBase.formatted("toolresponselabel", toolResponseLabel),
propertyBase.formatted("medialabel", mediaLabel))
.run(context -> {
Neo4jChatMemory chatMemory = context.getBean(Neo4jChatMemory.class);
Neo4jChatMemoryRepository chatMemory = context.getBean(Neo4jChatMemoryRepository.class);
Neo4jChatMemoryConfig config = chatMemory.getConfig();
assertThat(config.getMessageLabel()).isEqualTo(messageLabel);
assertThat(config.getMediaLabel()).isEqualTo(mediaLabel);

View File

@@ -0,0 +1,49 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.model.embedding.observation.autoconfigure;
import io.micrometer.core.instrument.MeterRegistry;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.context.annotation.Bean;
/**
* Auto-configuration for Spring AI embedding model observations.
*
* @author Thomas Vitale
* @since 1.0.0
*/
@AutoConfiguration(
afterName = "org.springframework.boot.actuate.autoconfigure.observation.ObservationAutoConfiguration")
@ConditionalOnClass(EmbeddingModel.class)
public class EmbeddingObservationAutoConfiguration {
@Bean
@ConditionalOnMissingBean
@ConditionalOnBean(MeterRegistry.class)
EmbeddingModelMeterObservationHandler embeddingModelMeterObservationHandler(
ObjectProvider<MeterRegistry> meterRegistry) {
return new EmbeddingModelMeterObservationHandler(meterRegistry.getObject());
}
}

View File

@@ -0,0 +1,50 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.model.embedding.observation.autoconfigure;
import io.micrometer.core.instrument.composite.CompositeMeterRegistry;
import org.junit.jupiter.api.Test;
import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Unit tests for {@link EmbeddingObservationAutoConfiguration}.
*
* @author Thomas Vitale
*/
class EmbeddingObservationAutoConfigurationTests {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(EmbeddingObservationAutoConfiguration.class));
@Test
void meterObservationHandlerEnabled() {
this.contextRunner.withBean(CompositeMeterRegistry.class)
.run(context -> assertThat(context).hasSingleBean(EmbeddingModelMeterObservationHandler.class));
}
@Test
void meterObservationHandlerDisabled() {
this.contextRunner
.run(context -> assertThat(context).doesNotHaveBean(EmbeddingModelMeterObservationHandler.class));
}
}

View File

@@ -50,6 +50,7 @@
<artifactId>spring-data-neo4j</artifactId>
</dependency>
<!-- TESTING -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
@@ -68,6 +69,12 @@
<artifactId>spring-boot-testcontainers</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.neo4j.driver</groupId>
@@ -79,6 +86,12 @@
<artifactId>neo4j</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@@ -93,6 +93,33 @@ public final class Neo4jChatMemoryConfig {
this.toolCallLabel = builder.toolCallLabel;
this.metadataLabel = builder.metadataLabel;
this.toolResponseLabel = builder.toolResponseLabel;
ensureIndexes();
}
/**
* Ensures that indexes exist on conversationId for Session nodes and index for
* Message nodes. This improves query performance for lookups and ordering.
*/
private void ensureIndexes() {
if (this.driver == null) {
logger.warn("Neo4j Driver is null, cannot ensure indexes.");
return;
}
try (var session = this.driver.session()) {
// Index for conversationId on Session nodes
String sessionIndexCypher = String.format(
"CREATE INDEX session_conversation_id_index IF NOT EXISTS FOR (n:%s) ON (n.conversationId)",
this.sessionLabel);
// Index for index on Message nodes
String messageIndexCypher = String
.format("CREATE INDEX message_index_index IF NOT EXISTS FOR (n:%s) ON (n.index)", this.messageLabel);
session.run(sessionIndexCypher);
session.run(messageIndexCypher);
logger.info("Ensured Neo4j indexes for conversationId and message index.");
}
catch (Exception e) {
logger.warn("Failed to ensure Neo4j indexes for chat memory: {}", e.getMessage());
}
}
public static Builder builder() {

View File

@@ -0,0 +1,293 @@
package org.springframework.ai.chat.memory.neo4j;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.neo4j.driver.Transaction;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.content.Media;
import org.springframework.ai.content.MediaContent;
import org.springframework.util.MimeType;
import java.net.URI;
import java.util.*;
/**
* An implementation of {@link ChatMemoryRepository} for Neo4J
*
* @author Enrico Rampazzo
* @since 1.0.0
*/
public class Neo4jChatMemoryRepository implements ChatMemoryRepository {
private final Neo4jChatMemoryConfig config;
public Neo4jChatMemoryRepository(Neo4jChatMemoryConfig config) {
this.config = config;
}
@Override
public List<String> findConversationIds() {
try (var session = config.getDriver().session()) {
return session.run("MATCH (conversation:%s) RETURN conversation.id".formatted(config.getSessionLabel()))
.stream()
.map(r -> r.get("conversation.id").asString())
.toList();
}
}
@Override
public List<Message> findByConversationId(String conversationId) {
String statementBuilder = """
MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s)
WITH m
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s)
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) WITH m, metadata, media ORDER BY media.idx ASC
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) WITH m, metadata, media, tr ORDER BY tr.idx ASC
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC
RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias
ORDER BY m.idx ASC
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(),
this.config.getToolCallLabel());
Result res = this.config.getDriver().session().run(statementBuilder, Map.of("conversationId", conversationId));
return res.stream().map(record -> {
Map<String, Object> messageMap = record.get("m").asMap();
String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString();
Message message = null;
List<Media> mediaList = List.of();
if (!record.get("medias").isNull()) {
mediaList = getMedia(record);
}
if (msgType.equals(MessageType.USER.getValue())) {
message = buildUserMessage(record, messageMap, mediaList);
}
if (msgType.equals(MessageType.ASSISTANT.getValue())) {
message = buildAssistantMessage(record, messageMap, mediaList);
}
if (msgType.equals(MessageType.SYSTEM.getValue())) {
SystemMessage.Builder systemMessageBuilder = SystemMessage.builder()
.text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString());
if (!record.get("metadata").isNull()) {
Map<String, Object> retrievedMetadata = record.get("metadata").asMap();
systemMessageBuilder.metadata(retrievedMetadata);
}
message = systemMessageBuilder.build();
}
if (msgType.equals(MessageType.TOOL.getValue())) {
message = buildToolMessage(record);
}
if (message == null) {
throw new IllegalArgumentException("%s messages are not supported"
.formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString()));
}
message.getMetadata().put("messageType", message.getMessageType());
return message;
}).toList();
}
@Override
public void saveAll(String conversationId, List<Message> messages) {
// First delete existing messages for this conversation
deleteByConversationId(conversationId);
// Then add the new messages
try (Session s = this.config.getDriver().session()) {
try (Transaction t = s.beginTransaction()) {
for (Message m : messages) {
addMessageToTransaction(t, conversationId, m);
}
t.commit();
}
}
}
@Override
public void deleteByConversationId(String conversationId) {
// First delete all messages and related nodes
String deleteMessagesStatement = """
MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s)
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s)
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s)
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s)
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
DETACH DELETE m, metadata, media, tr, tc
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(),
this.config.getToolCallLabel());
// Then delete the conversation node itself
String deleteConversationStatement = """
MATCH (s:%s {id:$conversationId})
DETACH DELETE s
""".formatted(this.config.getSessionLabel());
try (Session s = this.config.getDriver().session()) {
try (Transaction t = s.beginTransaction()) {
// First delete messages
t.run(deleteMessagesStatement, Map.of("conversationId", conversationId));
// Then delete the conversation node
t.run(deleteConversationStatement, Map.of("conversationId", conversationId));
t.commit();
}
}
}
public Neo4jChatMemoryConfig getConfig() {
return this.config;
}
private Message buildToolMessage(org.neo4j.driver.Record record) {
Message message;
message = new ToolResponseMessage(record.get("toolResponses").asList(v -> {
Map<String, Object> trMap = v.asMap();
return new ToolResponseMessage.ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()),
(String) trMap.get(ToolResponseAttributes.NAME.getValue()),
(String) trMap.get(ToolResponseAttributes.RESPONSE_DATA.getValue()));
}), record.get("metadata").asMap());
return message;
}
private Message buildAssistantMessage(org.neo4j.driver.Record record, Map<String, Object> messageMap,
List<Media> mediaList) {
Message message;
message = new AssistantMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(),
record.get("metadata").asMap(Map.of()), record.get("toolCalls").asList(v -> {
var toolCallMap = v.asMap();
return new AssistantMessage.ToolCall((String) toolCallMap.get("id"),
(String) toolCallMap.get("type"), (String) toolCallMap.get("name"),
(String) toolCallMap.get("arguments"));
}), mediaList);
return message;
}
private Message buildUserMessage(org.neo4j.driver.Record record, Map<String, Object> messageMap,
List<Media> mediaList) {
Message message;
Map<String, Object> metadata = record.get("metadata").asMap();
message = UserMessage.builder()
.text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString())
.media(mediaList)
.metadata(metadata)
.build();
return message;
}
private List<Media> getMedia(org.neo4j.driver.Record record) {
List<Media> mediaList;
mediaList = record.get("medias").asList(v -> {
Map<String, Object> mediaMap = v.asMap();
var mediaBuilder = Media.builder()
.name((String) mediaMap.get(MediaAttributes.NAME.getValue()))
.id(Optional.ofNullable(mediaMap.get(MediaAttributes.ID.getValue())).map(Object::toString).orElse(null))
.mimeType(MimeType.valueOf(mediaMap.get(MediaAttributes.MIME_TYPE.getValue()).toString()));
if (mediaMap.get(MediaAttributes.DATA.getValue()) instanceof String stringData) {
mediaBuilder.data(URI.create(stringData));
}
else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) {
mediaBuilder.data(mediaMap.get(MediaAttributes.DATA.getValue()));
}
return mediaBuilder.build();
});
return mediaList;
}
private void addMessageToTransaction(Transaction t, String conversationId, Message message) {
Map<String, Object> queryParameters = new HashMap<>();
queryParameters.put("conversationId", conversationId);
StringBuilder statementBuilder = new StringBuilder();
statementBuilder.append("""
MERGE (s:%s {id:$conversationId}) WITH s
OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s
CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties
SET msg.idx = totalMsg + 1
""".formatted(this.config.getSessionLabel(), this.config.getMessageLabel(),
this.config.getMessageLabel()));
Map<String, Object> attributes = new HashMap<>();
attributes.put(MessageAttributes.MESSAGE_TYPE.getValue(), message.getMessageType().getValue());
attributes.put(MessageAttributes.TEXT_CONTENT.getValue(), message.getText());
attributes.put("id", UUID.randomUUID().toString());
queryParameters.put("messageProperties", attributes);
if (!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) {
statementBuilder.append("""
WITH msg
CREATE (metadataNode:%s)
CREATE (msg)-[:HAS_METADATA]->(metadataNode)
SET metadataNode = $metadata
""".formatted(this.config.getMetadataLabel()));
Map<String, Object> metadataCopy = new HashMap<>(message.getMetadata());
metadataCopy.remove("messageType");
queryParameters.put("metadata", metadataCopy);
}
if (message instanceof AssistantMessage assistantMessage) {
if (assistantMessage.hasToolCalls()) {
statementBuilder.append("""
WITH msg
FOREACH(tc in $toolCalls | CREATE (toolCall:%s) SET toolCall = tc
CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall))
""".formatted(this.config.getToolCallLabel()));
List<Map<String, Object>> toolCallMaps = new ArrayList<>();
for (int i = 0; i < assistantMessage.getToolCalls().size(); i++) {
AssistantMessage.ToolCall tc = assistantMessage.getToolCalls().get(i);
toolCallMaps
.add(Map.of(ToolCallAttributes.ID.getValue(), tc.id(), ToolCallAttributes.NAME.getValue(),
tc.name(), ToolCallAttributes.ARGUMENTS.getValue(), tc.arguments(),
ToolCallAttributes.TYPE.getValue(), tc.type(), ToolCallAttributes.IDX.getValue(), i));
}
queryParameters.put("toolCalls", toolCallMaps);
}
}
if (message instanceof ToolResponseMessage toolResponseMessage) {
List<ToolResponseMessage.ToolResponse> toolResponses = toolResponseMessage.getResponses();
List<Map<String, String>> toolResponseMaps = new ArrayList<>();
for (int i = 0; i < Optional.ofNullable(toolResponses).orElse(List.of()).size(); i++) {
var toolResponse = toolResponses.get(i);
Map<String, String> toolResponseMap = Map.of(ToolResponseAttributes.ID.getValue(), toolResponse.id(),
ToolResponseAttributes.NAME.getValue(), toolResponse.name(),
ToolResponseAttributes.RESPONSE_DATA.getValue(), toolResponse.responseData(),
ToolResponseAttributes.IDX.getValue(), Integer.toString(i));
toolResponseMaps.add(toolResponseMap);
}
statementBuilder.append("""
WITH msg
FOREACH(tr IN $toolResponses | CREATE (tm:%s)
SET tm = tr
MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm))
""".formatted(this.config.getToolResponseLabel()));
queryParameters.put("toolResponses", toolResponseMaps);
}
if (message instanceof MediaContent messageWithMedia && !messageWithMedia.getMedia().isEmpty()) {
List<Map<String, Object>> mediaNodes = convertMediaToMap(messageWithMedia.getMedia());
statementBuilder.append("""
WITH msg
UNWIND $media AS m
CREATE (media:%s) SET media = m
WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media)
""".formatted(this.config.getMediaLabel()));
queryParameters.put("media", mediaNodes);
}
t.run(statementBuilder.toString(), queryParameters);
}
private List<Map<String, Object>> convertMediaToMap(List<Media> media) {
List<Map<String, Object>> mediaMaps = new ArrayList<>();
for (int i = 0; i < media.size(); i++) {
Map<String, Object> mediaMap = new HashMap<>();
Media m = media.get(i);
mediaMap.put(MediaAttributes.ID.getValue(), m.getId());
mediaMap.put(MediaAttributes.MIME_TYPE.getValue(), m.getMimeType().toString());
mediaMap.put(MediaAttributes.NAME.getValue(), m.getName());
mediaMap.put(MediaAttributes.DATA.getValue(), m.getData());
mediaMap.put(MediaAttributes.IDX.getValue(), i);
mediaMaps.add(mediaMap);
}
return mediaMaps;
}
}

View File

@@ -0,0 +1,91 @@
package org.springframework.ai.chat.memory.neo4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.AfterAll;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Session;
import org.neo4j.driver.Result;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import static org.assertj.core.api.Assertions.assertThat;
@Testcontainers
class Neo4jChatMemoryConfigIT {
@Container
static final Neo4jContainer<?> neo4jContainer = new Neo4jContainer<>("neo4j:5").withoutAuthentication();
static Driver driver;
@BeforeAll
static void setupDriver() {
driver = GraphDatabase.driver(neo4jContainer.getBoltUrl());
}
@AfterAll
static void closeDriver() {
if (driver != null)
driver.close();
}
@Test
void shouldCreateRequiredIndexes() {
// Given
Neo4jChatMemoryConfig config = Neo4jChatMemoryConfig.builder().withDriver(driver).build();
// When
try (Session session = driver.session()) {
Result result = session.run("SHOW INDEXES");
boolean sessionIndexFound = false;
boolean messageIndexFound = false;
while (result.hasNext()) {
var record = result.next();
String name = record.get("name").asString();
if ("session_conversation_id_index".equals(name))
sessionIndexFound = true;
if ("message_index_index".equals(name))
messageIndexFound = true;
}
// Then
assertThat(sessionIndexFound).isTrue();
assertThat(messageIndexFound).isTrue();
}
}
@Test
void builderShouldSetCustomLabels() {
String customSessionLabel = "ChatSession";
String customMessageLabel = "ChatMessage";
Neo4jChatMemoryConfig config = Neo4jChatMemoryConfig.builder()
.withDriver(driver)
.withSessionLabel(customSessionLabel)
.withMessageLabel(customMessageLabel)
.build();
assertThat(config.getSessionLabel()).isEqualTo(customSessionLabel);
assertThat(config.getMessageLabel()).isEqualTo(customMessageLabel);
}
@Test
void gettersShouldReturnConfiguredValues() {
Neo4jChatMemoryConfig config = Neo4jChatMemoryConfig.builder()
.withDriver(driver)
.withSessionLabel("Session")
.withToolCallLabel("ToolCall")
.withMetadataLabel("Metadata")
.withMessageLabel("Message")
.withToolResponseLabel("ToolResponse")
.withMediaLabel("Media")
.build();
assertThat(config.getSessionLabel()).isEqualTo("Session");
assertThat(config.getToolCallLabel()).isEqualTo("ToolCall");
assertThat(config.getMetadataLabel()).isEqualTo("Metadata");
assertThat(config.getMessageLabel()).isEqualTo("Message");
assertThat(config.getToolResponseLabel()).isEqualTo("ToolResponse");
assertThat(config.getMediaLabel()).isEqualTo("Media");
assertThat(config.getDriver()).isNotNull();
}
}

View File

@@ -0,0 +1,420 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.memory.neo4j;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.content.Media;
import org.springframework.util.MimeType;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Integration tests for {@link Neo4jChatMemoryRepository}.
*
* @author Enrico Rampazzo
* @since 1.0.0
*/
@Testcontainers
class Neo4jChatMemoryRepositoryIT {
static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j");
@SuppressWarnings({ "rawtypes", "resource" })
@Container
static Neo4jContainer neo4jContainer = (Neo4jContainer) new Neo4jContainer(DEFAULT_IMAGE_NAME.withTag("5"))
.withoutAuthentication()
.withExposedPorts(7474, 7687);
private ChatMemoryRepository chatMemoryRepository;
private Driver driver;
private Neo4jChatMemoryConfig config;
@BeforeEach
void setUp() {
driver = Neo4jDriverFactory.create(neo4jContainer.getBoltUrl());
config = Neo4jChatMemoryConfig.builder().withDriver(driver).build();
chatMemoryRepository = new Neo4jChatMemoryRepository(config);
}
@AfterEach
void tearDown() {
// Clean up all data after each test
try (Session session = driver.session()) {
session.run("MATCH (n) DETACH DELETE n");
}
driver.close();
}
@Test
void correctChatMemoryRepositoryInstance() {
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
assertThat(chatMemoryRepository).isInstanceOf(Neo4jChatMemoryRepository.class);
}
@ParameterizedTest
@CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM",
"Message from tool,TOOL" })
void saveAndFindSingleMessage(String content, MessageType messageType) {
var conversationId = UUID.randomUUID().toString();
Message message = createMessageByType(content + " - " + conversationId, messageType);
chatMemoryRepository.saveAll(conversationId, List.<Message>of(message));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
Message retrievedMessage = retrievedMessages.get(0);
assertThat(retrievedMessage.getMessageType()).isEqualTo(messageType);
if (messageType != MessageType.TOOL) {
assertThat(retrievedMessage.getText()).isEqualTo(message.getText());
}
// Verify directly in the database
try (Session session = driver.session()) {
var result = session.run(
"MATCH (s:%s {id:$conversationId})-[:HAS_MESSAGE]->(m:%s) RETURN count(m) as count"
.formatted(config.getSessionLabel(), config.getMessageLabel()),
Map.of("conversationId", conversationId));
assertThat(result.single().get("count").asLong()).isEqualTo(1);
}
}
@Test
void saveAndFindMultipleMessages() {
var conversationId = UUID.randomUUID().toString();
List<Message> messages = List.of(new AssistantMessage("Message from assistant - " + conversationId),
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId),
new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData"))));
chatMemoryRepository.saveAll(conversationId, messages);
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(messages.size());
// Verify the order is preserved (ascending by index)
for (int i = 0; i < messages.size(); i++) {
if (messages.get(i).getMessageType() != MessageType.TOOL) {
assertThat(retrievedMessages.get(i).getText()).isEqualTo(messages.get(i).getText());
}
assertThat(retrievedMessages.get(i).getMessageType()).isEqualTo(messages.get(i).getMessageType());
}
}
@Test
void verifyMessageOrdering() {
var conversationId = UUID.randomUUID().toString();
List<Message> messages = new ArrayList<>();
// Add messages in a specific order
for (int i = 1; i <= 5; i++) {
messages.add(new UserMessage("Message " + i));
}
chatMemoryRepository.saveAll(conversationId, messages);
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(messages.size());
// Verify that messages are returned in ascending order (oldest first)
for (int i = 0; i < messages.size(); i++) {
assertThat(retrievedMessages.get(i).getText()).isEqualTo("Message " + (i + 1));
}
}
@Test
void findConversationIds() {
// Create multiple conversations
var conversationId1 = UUID.randomUUID().toString();
var conversationId2 = UUID.randomUUID().toString();
var conversationId3 = UUID.randomUUID().toString();
chatMemoryRepository.saveAll(conversationId1, List.<Message>of(new UserMessage("Message for conversation 1")));
chatMemoryRepository.saveAll(conversationId2, List.<Message>of(new UserMessage("Message for conversation 2")));
chatMemoryRepository.saveAll(conversationId3, List.<Message>of(new UserMessage("Message for conversation 3")));
List<String> conversationIds = chatMemoryRepository.findConversationIds();
assertThat(conversationIds).hasSize(3);
assertThat(conversationIds).contains(conversationId1, conversationId2, conversationId3);
}
@Test
void deleteByConversationId() {
var conversationId = UUID.randomUUID().toString();
List<Message> messages = List.of(new AssistantMessage("Message from assistant"),
new UserMessage("Message from user"), new SystemMessage("Message from system"));
chatMemoryRepository.saveAll(conversationId, messages);
// Verify messages were saved
assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
// Delete the conversation
chatMemoryRepository.deleteByConversationId(conversationId);
// Verify messages were deleted
assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEmpty();
// Verify directly in the database
try (Session session = driver.session()) {
var result = session.run(
"MATCH (s:%s {id:$conversationId}) RETURN count(s) as count".formatted(config.getSessionLabel()),
Map.of("conversationId", conversationId));
assertThat(result.single().get("count").asLong()).isZero();
}
}
@Test
void saveAllReplacesExistingMessages() {
var conversationId = UUID.randomUUID().toString();
// Save initial messages
List<Message> initialMessages = List.of(new UserMessage("Initial message 1"),
new UserMessage("Initial message 2"), new UserMessage("Initial message 3"));
chatMemoryRepository.saveAll(conversationId, initialMessages);
// Verify initial messages were saved
assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
// Replace with new messages
List<Message> newMessages = List.of(new UserMessage("New message 1"), new UserMessage("New message 2"));
chatMemoryRepository.saveAll(conversationId, newMessages);
// Verify only new messages exist
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(2);
assertThat(retrievedMessages.get(0).getText()).isEqualTo("New message 1");
assertThat(retrievedMessages.get(1).getText()).isEqualTo("New message 2");
}
@Test
void handleMediaContent() {
var conversationId = UUID.randomUUID().toString();
MimeType textPlain = MimeType.valueOf("text/plain");
List<Media> media = List.of(Media.builder()
.name("some media")
.id(UUID.randomUUID().toString())
.mimeType(textPlain)
.data("hello".getBytes(StandardCharsets.UTF_8))
.build(), Media.builder().data(URI.create("http://www.example.com")).mimeType(textPlain).build());
UserMessage userMessageWithMedia = UserMessage.builder().text("Message with media").media(media).build();
chatMemoryRepository.saveAll(conversationId, List.<Message>of(userMessageWithMedia));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
UserMessage retrievedMessage = (UserMessage) retrievedMessages.get(0);
assertThat(retrievedMessage.getMedia()).hasSize(2);
assertThat(retrievedMessage.getMedia()).usingRecursiveFieldByFieldElementComparator().isEqualTo(media);
}
@Test
void handleAssistantMessageWithToolCalls() {
var conversationId = UUID.randomUUID().toString();
AssistantMessage assistantMessage = new AssistantMessage("Message with tool calls", Map.of(),
List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"),
new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2")));
chatMemoryRepository.saveAll(conversationId, List.<Message>of(assistantMessage));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
AssistantMessage retrievedMessage = (AssistantMessage) retrievedMessages.get(0);
assertThat(retrievedMessage.getToolCalls()).hasSize(2);
assertThat(retrievedMessage.getToolCalls().get(0).id()).isEqualTo("id1");
assertThat(retrievedMessage.getToolCalls().get(1).id()).isEqualTo("id2");
}
@Test
void handleToolResponseMessage() {
var conversationId = UUID.randomUUID().toString();
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List
.of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")),
Map.of("metadataKey", "metadataValue"));
chatMemoryRepository.saveAll(conversationId, List.<Message>of(toolResponseMessage));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
ToolResponseMessage retrievedMessage = (ToolResponseMessage) retrievedMessages.get(0);
assertThat(retrievedMessage.getResponses()).hasSize(2);
assertThat(retrievedMessage.getResponses().get(0).id()).isEqualTo("id1");
assertThat(retrievedMessage.getResponses().get(1).id()).isEqualTo("id2");
assertThat(retrievedMessage.getMetadata()).containsEntry("metadataKey", "metadataValue");
}
@Test
void saveAndFindSystemMessageWithMetadata() {
var conversationId = UUID.randomUUID().toString();
Map<String, Object> customMetadata = Map.of("priority", "high", "source", "test");
SystemMessage systemMessage = SystemMessage.builder()
.text("System message with custom metadata - " + conversationId)
.metadata(customMetadata)
.build();
chatMemoryRepository.saveAll(conversationId, List.of(systemMessage));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
Message retrievedMessage = retrievedMessages.get(0);
assertThat(retrievedMessage).isInstanceOf(SystemMessage.class);
assertThat(retrievedMessage.getText()).isEqualTo("System message with custom metadata - " + conversationId);
// Crucial assertion for the metadata
assertThat(retrievedMessage.getMetadata()).containsAllEntriesOf(customMetadata);
// Also check that the 'messageType' key is present (added by the repository)
assertThat(retrievedMessage.getMetadata()).containsEntry("messageType", MessageType.SYSTEM);
// Verify no extra unwanted metadata keys beyond what's expected
assertThat(retrievedMessage.getMetadata().keySet())
.containsExactlyInAnyOrderElementsOf(new ArrayList<>(customMetadata.keySet()) {
{
add("messageType");
}
});
}
@Test
void saveAllWithEmptyListClearsConversation() {
var conversationId = UUID.randomUUID().toString();
// 1. Setup: Create a conversation with some initial messages
UserMessage initialMessage1 = new UserMessage("Initial message 1");
AssistantMessage initialMessage2 = new AssistantMessage("Initial response 1");
chatMemoryRepository.saveAll(conversationId, List.of(initialMessage1, initialMessage2));
// Verify initial messages are there
List<Message> messagesAfterInitialSave = chatMemoryRepository.findByConversationId(conversationId);
assertThat(messagesAfterInitialSave).hasSize(2);
// 2. Action: Call saveAll with an empty list
chatMemoryRepository.saveAll(conversationId, Collections.emptyList());
// 3. Assertions:
// a) No messages should be found for the conversationId
List<Message> messagesAfterEmptySave = chatMemoryRepository.findByConversationId(conversationId);
assertThat(messagesAfterEmptySave).isEmpty();
// b) The conversationId itself should no longer be listed (because
// deleteByConversationId removes the session node)
List<String> conversationIds = chatMemoryRepository.findConversationIds();
assertThat(conversationIds).doesNotContain(conversationId);
// c) Verify directly in Neo4j that the conversation node is gone
try (Session session = driver.session()) {
Result result = session.run(
"MATCH (s:%s {id: $conversationId}) RETURN s".formatted(config.getSessionLabel()),
Map.of("conversationId", conversationId));
assertThat(result.hasNext()).isFalse(); // No conversation node should exist
}
}
@Test
void saveAndFindMessagesWithEmptyContentOrMetadata() {
var conversationId = UUID.randomUUID().toString();
UserMessage messageWithEmptyContent = new UserMessage("");
UserMessage messageWithEmptyMetadata = UserMessage.builder()
.text("Content with empty metadata")
.metadata(Collections.emptyMap())
.build();
List<Message> messagesToSave = List.of(messageWithEmptyContent, messageWithEmptyMetadata);
chatMemoryRepository.saveAll(conversationId, messagesToSave);
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(2);
// Verify first message (empty content)
Message retrievedEmptyContentMsg = retrievedMessages.get(0);
assertThat(retrievedEmptyContentMsg).isInstanceOf(UserMessage.class);
assertThat(retrievedEmptyContentMsg.getText()).isEqualTo("");
assertThat(retrievedEmptyContentMsg.getMetadata()).containsEntry("messageType", MessageType.USER); // Default
// metadata
assertThat(retrievedEmptyContentMsg.getMetadata().keySet()).hasSize(1); // Only
// messageType
// Verify second message (empty metadata from input, should only have messageType
// after retrieval)
Message retrievedEmptyMetadataMsg = retrievedMessages.get(1);
assertThat(retrievedEmptyMetadataMsg).isInstanceOf(UserMessage.class);
assertThat(retrievedEmptyMetadataMsg.getText()).isEqualTo("Content with empty metadata");
assertThat(retrievedEmptyMetadataMsg.getMetadata()).containsEntry("messageType", MessageType.USER);
assertThat(retrievedEmptyMetadataMsg.getMetadata().keySet()).hasSize(1); // Only
// messageType
}
private Message createMessageByType(String content, MessageType messageType) {
return switch (messageType) {
case ASSISTANT -> new AssistantMessage(content);
case USER -> new UserMessage(content);
case SYSTEM -> new SystemMessage(content);
case TOOL -> new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData")));
};
}
/**
* Factory for creating Neo4j Driver instances.
*/
private static class Neo4jDriverFactory {
static Driver create(String boltUrl) {
return org.neo4j.driver.GraphDatabase.driver(boltUrl);
}
}
}

View File

@@ -64,7 +64,7 @@ If you'd rather create the `InMemoryChatMemoryRepository` manually, you can do s
ChatMemoryRepository repository = new InMemoryChatMemoryRepository();
----
=== JDBC Repository
=== JdbcChatMemoryRepository
`JdbcChatMemoryRepository` is a built-in implementation that uses JDBC to store messages in a relational database. It is suitable for applications that require persistent storage of chat memory.
@@ -127,6 +127,7 @@ ChatMemory chatMemory = MessageWindowChatMemory.builder()
| `spring.ai.chat.memory.repository.jdbc.initialize-schema` | Whether to initialize the schema on startup. | `true`
|===
==== Schema Initialization
The auto-configuration will automatically create the `ai_chat_memory` table using the JDBC driver. Currently, only PostgreSQL and MariaDB are supported.
@@ -135,6 +136,78 @@ You can disable the schema initialization by setting the property `spring.ai.cha
If your project uses a tool like Flyway or Liquibase to manage your database schemas, you can disable the schema initialization and refer to link:https://github.com/spring-projects/spring-ai/tree/main/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc[these SQL scripts] for configuring those tools to create the `ai_chat_memory` table.
=== Neo4j ChatMemoryRepository
`Neo4jChatMemoryRepository` is a built-in implementation that uses Neo4j to store chat messages as nodes and relationships in a property graph database. It is suitable for applications that want to leverage Neo4j's graph capabilities for chat memory persistence.
First, add the following dependency to your project:
[tabs]
======
Maven::
+
[source, xml]
----
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-chat-memory-neo4j</artifactId>
</dependency>
----
Gradle::
+
[source,groovy]
----
dependencies {
implementation 'org.springframework.ai:spring-ai-starter-model-chat-memory-neo4j'
}
----
======
Spring AI provides auto-configuration for the `Neo4jChatMemoryRepository`, which you can use directly in your application.
[source,java]
----
@Autowired
Neo4jChatMemoryRepository chatMemoryRepository;
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository)
.maxMessages(10)
.build();
----
If you'd rather create the `Neo4jChatMemoryRepository` manually, you can do so by providing a Neo4j `Driver` instance:
[source,java]
----
ChatMemoryRepository chatMemoryRepository = Neo4jChatMemoryRepository.builder()
.driver(driver)
.build();
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository)
.maxMessages(10)
.build();
----
==== Configuration Properties
[cols="2,5,1",stripes=even]
|===
|Property | Description | Default Value
| `spring.ai.chat.memory.neo4j.sessionLabel` | The label for the nodes that store conversation sessions | `Session`
| `spring.ai.chat.memory.neo4j.messageLabel` | The label for the nodes that store messages | `Message`
| `spring.ai.chat.memory.neo4j.toolCallLabel` | The label for nodes that store tool calls (e.g. in Assistant Messages) | `ToolCall`
| `spring.ai.chat.memory.neo4j.metadataLabel` | The label for nodes that store message metadata | `Metadata`
| `spring.ai.chat.memory.neo4j.toolResponseLabel` | The label for the nodes that store tool responses | `ToolResponse`
| `spring.ai.chat.memory.neo4j.mediaLabel` | The label for the nodes that store media associated with a message | `Media`
|===
==== Index Initialization
The Neo4j repository will automatically ensure that indexes are created for conversation IDs and message indices to optimize performance. If you use custom labels, indexes will be created for those labels as well. No schema initialization is required, but you should ensure your Neo4j instance is accessible to your application.
== Memory in Chat Client
When using the ChatClient API, you can provide a `ChatMemory` implementation to maintain conversation context across multiple interactions.