diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests.java new file mode 100644 index 000000000..cd53f2bd8 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import javax.sql.DataSource; + +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class JdbcChatMemoryDataSourceScriptDatabaseInitializerPostgresqlTests { + + static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17"); + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(DEFAULT_IMAGE_NAME) + .withDatabaseName("chat_memory_initializer_test") + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withPropertyValues(String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("spring.datasource.username=%s", postgresContainer.getUsername()), + String.format("spring.datasource.password=%s", postgresContainer.getPassword())); + + @Test + void getSettings_shouldHaveSchemaLocations() { + this.contextRunner.run(context -> { + var dataSource = context.getBean(DataSource.class); + var settings = JdbcChatMemoryDataSourceScriptDatabaseInitializer.getSettings(dataSource); + + assertThat(settings.getSchemaLocations()) + .containsOnly("classpath:org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql"); + }); + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/main/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfiguration.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/main/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfiguration.java index 554b2a3d8..3a7f66aae 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/main/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfiguration.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/main/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfiguration.java @@ -18,8 +18,8 @@ package org.springframework.ai.model.chat.memory.neo4j.autoconfigure; import org.neo4j.driver.Driver; -import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory; import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig; +import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryRepository; import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -29,19 +29,19 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; /** - * {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemory}. + * {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemoryRepository}. * * @author Enrico Rampazzo * @since 1.0.0 */ @AutoConfiguration(after = Neo4jAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class) -@ConditionalOnClass({ Neo4jChatMemory.class, Driver.class }) +@ConditionalOnClass({ Neo4jChatMemoryRepository.class, Driver.class }) @EnableConfigurationProperties(Neo4jChatMemoryProperties.class) public class Neo4jChatMemoryAutoConfiguration { @Bean @ConditionalOnMissingBean - public Neo4jChatMemory chatMemory(Neo4jChatMemoryProperties properties, Driver driver) { + public Neo4jChatMemoryRepository chatMemoryRepository(Neo4jChatMemoryProperties properties, Driver driver) { var builder = Neo4jChatMemoryConfig.builder() .withMediaLabel(properties.getMediaLabel()) @@ -52,7 +52,7 @@ public class Neo4jChatMemoryAutoConfiguration { .withToolResponseLabel(properties.getToolResponseLabel()) .withDriver(driver); - return Neo4jChatMemory.create(builder.build()); + return new Neo4jChatMemoryRepository(builder.build()); } } diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/test/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/test/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java similarity index 80% rename from auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/test/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfigurationIT.java rename to auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/test/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java index 7a4448784..5826a0ae9 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/test/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-neo4j/src/test/java/org/springframework/ai/model/chat/memory/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java @@ -23,12 +23,14 @@ import java.util.Map; import java.util.UUID; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryRepository; import org.testcontainers.containers.Neo4jContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.utility.DockerImageName; -import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory; import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -51,7 +53,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @since 1.0.0 */ @Testcontainers -class Neo4jChatMemoryAutoConfigurationIT { +class Neo4jChatMemoryRepositoryAutoConfigurationIT { static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j"); @@ -67,31 +69,31 @@ class Neo4jChatMemoryAutoConfigurationIT { @Test void addAndGet() { this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl()).run(context -> { - Neo4jChatMemory memory = context.getBean(Neo4jChatMemory.class); + ChatMemoryRepository memory = context.getBean(ChatMemoryRepository.class); String sessionId = UUID.randomUUID().toString(); - assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); + assertThat(memory.findByConversationId(sessionId)).isEmpty(); UserMessage userMessage = new UserMessage("test question"); - memory.add(sessionId, userMessage); - List messages = memory.get(sessionId, Integer.MAX_VALUE); + memory.saveAll(sessionId, List.of(userMessage)); + List messages = memory.findByConversationId(sessionId); assertThat(messages).hasSize(1); assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(userMessage); - memory.clear(sessionId); - assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); + memory.deleteByConversationId(sessionId); + assertThat(memory.findByConversationId(sessionId)).isEmpty(); AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(), List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments"))); - memory.add(sessionId, List.of(userMessage, assistantMessage)); - messages = memory.get(sessionId, Integer.MAX_VALUE); + memory.saveAll(sessionId, List.of(userMessage, assistantMessage)); + messages = memory.findByConversationId(sessionId); assertThat(messages).hasSize(2); - assertThat(messages.get(1)).isEqualTo(userMessage); + assertThat(messages.get(0)).isEqualTo(userMessage); - assertThat(messages.get(0)).isEqualTo(assistantMessage); - memory.clear(sessionId); + assertThat(messages.get(1)).isEqualTo(assistantMessage); + memory.deleteByConversationId(sessionId); MimeType textPlain = MimeType.valueOf("text/plain"); List media = List.of( Media.builder() @@ -102,28 +104,28 @@ class Neo4jChatMemoryAutoConfigurationIT { .build(), Media.builder().data(URI.create("http://www.google.com")).mimeType(textPlain).build()); UserMessage userMessageWithMedia = UserMessage.builder().text("Message with media").media(media).build(); - memory.add(sessionId, userMessageWithMedia); + memory.saveAll(sessionId, List.of(userMessageWithMedia)); - messages = memory.get(sessionId, Integer.MAX_VALUE); + messages = memory.findByConversationId(sessionId); assertThat(messages.size()).isEqualTo(1); assertThat(messages.get(0)).isEqualTo(userMessageWithMedia); assertThat(((UserMessage) messages.get(0)).getMedia()).hasSize(2); assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator() .isEqualTo(media); - memory.clear(sessionId); + memory.deleteByConversationId(sessionId); ToolResponseMessage toolResponseMessage = new ToolResponseMessage( List.of(new ToolResponse("id", "name", "responseData"), new ToolResponse("id2", "name2", "responseData2")), Map.of("id", "id", "metadataKey", "metadata")); - memory.add(sessionId, toolResponseMessage); - messages = memory.get(sessionId, Integer.MAX_VALUE); + memory.saveAll(sessionId, List.of(toolResponseMessage)); + messages = memory.findByConversationId(sessionId); assertThat(messages.size()).isEqualTo(1); assertThat(messages.get(0)).isEqualTo(toolResponseMessage); - memory.clear(sessionId); + memory.deleteByConversationId(sessionId); SystemMessage sm = new SystemMessage("this is a System message"); - memory.add(sessionId, sm); - messages = memory.get(sessionId, Integer.MAX_VALUE); + memory.saveAll(sessionId, List.of(sm)); + messages = memory.findByConversationId(sessionId); assertThat(messages).hasSize(1); assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(sm); }); @@ -148,7 +150,7 @@ class Neo4jChatMemoryAutoConfigurationIT { propertyBase.formatted("toolresponselabel", toolResponseLabel), propertyBase.formatted("medialabel", mediaLabel)) .run(context -> { - Neo4jChatMemory chatMemory = context.getBean(Neo4jChatMemory.class); + Neo4jChatMemoryRepository chatMemory = context.getBean(Neo4jChatMemoryRepository.class); Neo4jChatMemoryConfig config = chatMemory.getConfig(); assertThat(config.getMessageLabel()).isEqualTo(messageLabel); assertThat(config.getMediaLabel()).isEqualTo(mediaLabel); diff --git a/auto-configurations/models/embedding/observation/spring-ai-autoconfigure-model-embedding-observation/src/main/java/org/springframework/ai/model/embedding/observation/autoconfigure/EmbeddingObservationAutoConfiguration.java b/auto-configurations/models/embedding/observation/spring-ai-autoconfigure-model-embedding-observation/src/main/java/org/springframework/ai/model/embedding/observation/autoconfigure/EmbeddingObservationAutoConfiguration.java new file mode 100644 index 000000000..ee0902343 --- /dev/null +++ b/auto-configurations/models/embedding/observation/spring-ai-autoconfigure-model-embedding-observation/src/main/java/org/springframework/ai/model/embedding/observation/autoconfigure/EmbeddingObservationAutoConfiguration.java @@ -0,0 +1,49 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.embedding.observation.autoconfigure; + +import io.micrometer.core.instrument.MeterRegistry; + +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.context.annotation.Bean; + +/** + * Auto-configuration for Spring AI embedding model observations. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +@AutoConfiguration( + afterName = "org.springframework.boot.actuate.autoconfigure.observation.ObservationAutoConfiguration") +@ConditionalOnClass(EmbeddingModel.class) +public class EmbeddingObservationAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + @ConditionalOnBean(MeterRegistry.class) + EmbeddingModelMeterObservationHandler embeddingModelMeterObservationHandler( + ObjectProvider meterRegistry) { + return new EmbeddingModelMeterObservationHandler(meterRegistry.getObject()); + } + +} diff --git a/auto-configurations/models/embedding/observation/spring-ai-autoconfigure-model-embedding-observation/src/test/java/org/springframework/ai/model/embedding/observation/autoconfigure/EmbeddingObservationAutoConfigurationTests.java b/auto-configurations/models/embedding/observation/spring-ai-autoconfigure-model-embedding-observation/src/test/java/org/springframework/ai/model/embedding/observation/autoconfigure/EmbeddingObservationAutoConfigurationTests.java new file mode 100644 index 000000000..10c9feb7c --- /dev/null +++ b/auto-configurations/models/embedding/observation/spring-ai-autoconfigure-model-embedding-observation/src/test/java/org/springframework/ai/model/embedding/observation/autoconfigure/EmbeddingObservationAutoConfigurationTests.java @@ -0,0 +1,50 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.embedding.observation.autoconfigure; + +import io.micrometer.core.instrument.composite.CompositeMeterRegistry; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link EmbeddingObservationAutoConfiguration}. + * + * @author Thomas Vitale + */ +class EmbeddingObservationAutoConfigurationTests { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(EmbeddingObservationAutoConfiguration.class)); + + @Test + void meterObservationHandlerEnabled() { + this.contextRunner.withBean(CompositeMeterRegistry.class) + .run(context -> assertThat(context).hasSingleBean(EmbeddingModelMeterObservationHandler.class)); + } + + @Test + void meterObservationHandlerDisabled() { + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(EmbeddingModelMeterObservationHandler.class)); + } + +} diff --git a/memory/spring-ai-model-chat-memory-neo4j/pom.xml b/memory/spring-ai-model-chat-memory-neo4j/pom.xml index 8f8248d7a..47e370364 100644 --- a/memory/spring-ai-model-chat-memory-neo4j/pom.xml +++ b/memory/spring-ai-model-chat-memory-neo4j/pom.xml @@ -50,6 +50,7 @@ spring-data-neo4j + org.springframework.boot spring-boot-starter-test @@ -68,6 +69,12 @@ spring-boot-testcontainers test + + + org.testcontainers + testcontainers + test + org.neo4j.driver @@ -79,6 +86,12 @@ neo4j test + + + org.testcontainers + junit-jupiter + test + diff --git a/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java b/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java index 9d4b23877..f5409f93c 100644 --- a/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java +++ b/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java @@ -93,6 +93,33 @@ public final class Neo4jChatMemoryConfig { this.toolCallLabel = builder.toolCallLabel; this.metadataLabel = builder.metadataLabel; this.toolResponseLabel = builder.toolResponseLabel; + ensureIndexes(); + } + + /** + * Ensures that indexes exist on conversationId for Session nodes and index for + * Message nodes. This improves query performance for lookups and ordering. + */ + private void ensureIndexes() { + if (this.driver == null) { + logger.warn("Neo4j Driver is null, cannot ensure indexes."); + return; + } + try (var session = this.driver.session()) { + // Index for conversationId on Session nodes + String sessionIndexCypher = String.format( + "CREATE INDEX session_conversation_id_index IF NOT EXISTS FOR (n:%s) ON (n.conversationId)", + this.sessionLabel); + // Index for index on Message nodes + String messageIndexCypher = String + .format("CREATE INDEX message_index_index IF NOT EXISTS FOR (n:%s) ON (n.index)", this.messageLabel); + session.run(sessionIndexCypher); + session.run(messageIndexCypher); + logger.info("Ensured Neo4j indexes for conversationId and message index."); + } + catch (Exception e) { + logger.warn("Failed to ensure Neo4j indexes for chat memory: {}", e.getMessage()); + } } public static Builder builder() { diff --git a/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java new file mode 100644 index 000000000..7aa3ee299 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java @@ -0,0 +1,293 @@ +package org.springframework.ai.chat.memory.neo4j; + +import org.neo4j.driver.Result; +import org.neo4j.driver.Session; +import org.neo4j.driver.Transaction; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.content.Media; +import org.springframework.ai.content.MediaContent; +import org.springframework.util.MimeType; + +import java.net.URI; +import java.util.*; + +/** + * An implementation of {@link ChatMemoryRepository} for Neo4J + * + * @author Enrico Rampazzo + * @since 1.0.0 + */ + +public class Neo4jChatMemoryRepository implements ChatMemoryRepository { + + private final Neo4jChatMemoryConfig config; + + public Neo4jChatMemoryRepository(Neo4jChatMemoryConfig config) { + this.config = config; + } + + @Override + public List findConversationIds() { + try (var session = config.getDriver().session()) { + return session.run("MATCH (conversation:%s) RETURN conversation.id".formatted(config.getSessionLabel())) + .stream() + .map(r -> r.get("conversation.id").asString()) + .toList(); + } + } + + @Override + public List findByConversationId(String conversationId) { + String statementBuilder = """ + MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s) + WITH m + OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s) + OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) WITH m, metadata, media ORDER BY media.idx ASC + OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) WITH m, metadata, media, tr ORDER BY tr.idx ASC + OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s) + WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC + RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias + ORDER BY m.idx ASC + """.formatted(this.config.getSessionLabel(), this.config.getMessageLabel(), + this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(), + this.config.getToolCallLabel()); + Result res = this.config.getDriver().session().run(statementBuilder, Map.of("conversationId", conversationId)); + return res.stream().map(record -> { + Map messageMap = record.get("m").asMap(); + String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString(); + Message message = null; + List mediaList = List.of(); + if (!record.get("medias").isNull()) { + mediaList = getMedia(record); + } + if (msgType.equals(MessageType.USER.getValue())) { + message = buildUserMessage(record, messageMap, mediaList); + } + if (msgType.equals(MessageType.ASSISTANT.getValue())) { + message = buildAssistantMessage(record, messageMap, mediaList); + } + if (msgType.equals(MessageType.SYSTEM.getValue())) { + SystemMessage.Builder systemMessageBuilder = SystemMessage.builder() + .text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()); + if (!record.get("metadata").isNull()) { + Map retrievedMetadata = record.get("metadata").asMap(); + systemMessageBuilder.metadata(retrievedMetadata); + } + message = systemMessageBuilder.build(); + } + if (msgType.equals(MessageType.TOOL.getValue())) { + message = buildToolMessage(record); + } + if (message == null) { + throw new IllegalArgumentException("%s messages are not supported" + .formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString())); + } + message.getMetadata().put("messageType", message.getMessageType()); + return message; + }).toList(); + + } + + @Override + public void saveAll(String conversationId, List messages) { + // First delete existing messages for this conversation + deleteByConversationId(conversationId); + + // Then add the new messages + try (Session s = this.config.getDriver().session()) { + try (Transaction t = s.beginTransaction()) { + for (Message m : messages) { + addMessageToTransaction(t, conversationId, m); + } + t.commit(); + } + } + } + + @Override + public void deleteByConversationId(String conversationId) { + // First delete all messages and related nodes + String deleteMessagesStatement = """ + MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s) + OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s) + OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) + OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) + OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s) + DETACH DELETE m, metadata, media, tr, tc + """.formatted(this.config.getSessionLabel(), this.config.getMessageLabel(), + this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(), + this.config.getToolCallLabel()); + + // Then delete the conversation node itself + String deleteConversationStatement = """ + MATCH (s:%s {id:$conversationId}) + DETACH DELETE s + """.formatted(this.config.getSessionLabel()); + + try (Session s = this.config.getDriver().session()) { + try (Transaction t = s.beginTransaction()) { + // First delete messages + t.run(deleteMessagesStatement, Map.of("conversationId", conversationId)); + // Then delete the conversation node + t.run(deleteConversationStatement, Map.of("conversationId", conversationId)); + t.commit(); + } + } + } + + public Neo4jChatMemoryConfig getConfig() { + return this.config; + } + + private Message buildToolMessage(org.neo4j.driver.Record record) { + Message message; + message = new ToolResponseMessage(record.get("toolResponses").asList(v -> { + Map trMap = v.asMap(); + return new ToolResponseMessage.ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()), + (String) trMap.get(ToolResponseAttributes.NAME.getValue()), + (String) trMap.get(ToolResponseAttributes.RESPONSE_DATA.getValue())); + }), record.get("metadata").asMap()); + return message; + } + + private Message buildAssistantMessage(org.neo4j.driver.Record record, Map messageMap, + List mediaList) { + Message message; + message = new AssistantMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(), + record.get("metadata").asMap(Map.of()), record.get("toolCalls").asList(v -> { + var toolCallMap = v.asMap(); + return new AssistantMessage.ToolCall((String) toolCallMap.get("id"), + (String) toolCallMap.get("type"), (String) toolCallMap.get("name"), + (String) toolCallMap.get("arguments")); + }), mediaList); + return message; + } + + private Message buildUserMessage(org.neo4j.driver.Record record, Map messageMap, + List mediaList) { + Message message; + Map metadata = record.get("metadata").asMap(); + message = UserMessage.builder() + .text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()) + .media(mediaList) + .metadata(metadata) + .build(); + return message; + } + + private List getMedia(org.neo4j.driver.Record record) { + List mediaList; + mediaList = record.get("medias").asList(v -> { + Map mediaMap = v.asMap(); + var mediaBuilder = Media.builder() + .name((String) mediaMap.get(MediaAttributes.NAME.getValue())) + .id(Optional.ofNullable(mediaMap.get(MediaAttributes.ID.getValue())).map(Object::toString).orElse(null)) + .mimeType(MimeType.valueOf(mediaMap.get(MediaAttributes.MIME_TYPE.getValue()).toString())); + if (mediaMap.get(MediaAttributes.DATA.getValue()) instanceof String stringData) { + mediaBuilder.data(URI.create(stringData)); + } + else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) { + mediaBuilder.data(mediaMap.get(MediaAttributes.DATA.getValue())); + } + return mediaBuilder.build(); + + }); + return mediaList; + } + + private void addMessageToTransaction(Transaction t, String conversationId, Message message) { + Map queryParameters = new HashMap<>(); + queryParameters.put("conversationId", conversationId); + StringBuilder statementBuilder = new StringBuilder(); + statementBuilder.append(""" + MERGE (s:%s {id:$conversationId}) WITH s + OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s + CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties + SET msg.idx = totalMsg + 1 + """.formatted(this.config.getSessionLabel(), this.config.getMessageLabel(), + this.config.getMessageLabel())); + Map attributes = new HashMap<>(); + + attributes.put(MessageAttributes.MESSAGE_TYPE.getValue(), message.getMessageType().getValue()); + attributes.put(MessageAttributes.TEXT_CONTENT.getValue(), message.getText()); + attributes.put("id", UUID.randomUUID().toString()); + queryParameters.put("messageProperties", attributes); + + if (!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) { + statementBuilder.append(""" + WITH msg + CREATE (metadataNode:%s) + CREATE (msg)-[:HAS_METADATA]->(metadataNode) + SET metadataNode = $metadata + """.formatted(this.config.getMetadataLabel())); + Map metadataCopy = new HashMap<>(message.getMetadata()); + metadataCopy.remove("messageType"); + queryParameters.put("metadata", metadataCopy); + } + if (message instanceof AssistantMessage assistantMessage) { + if (assistantMessage.hasToolCalls()) { + statementBuilder.append(""" + WITH msg + FOREACH(tc in $toolCalls | CREATE (toolCall:%s) SET toolCall = tc + CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall)) + """.formatted(this.config.getToolCallLabel())); + List> toolCallMaps = new ArrayList<>(); + for (int i = 0; i < assistantMessage.getToolCalls().size(); i++) { + AssistantMessage.ToolCall tc = assistantMessage.getToolCalls().get(i); + toolCallMaps + .add(Map.of(ToolCallAttributes.ID.getValue(), tc.id(), ToolCallAttributes.NAME.getValue(), + tc.name(), ToolCallAttributes.ARGUMENTS.getValue(), tc.arguments(), + ToolCallAttributes.TYPE.getValue(), tc.type(), ToolCallAttributes.IDX.getValue(), i)); + } + queryParameters.put("toolCalls", toolCallMaps); + } + } + if (message instanceof ToolResponseMessage toolResponseMessage) { + List toolResponses = toolResponseMessage.getResponses(); + List> toolResponseMaps = new ArrayList<>(); + for (int i = 0; i < Optional.ofNullable(toolResponses).orElse(List.of()).size(); i++) { + var toolResponse = toolResponses.get(i); + Map toolResponseMap = Map.of(ToolResponseAttributes.ID.getValue(), toolResponse.id(), + ToolResponseAttributes.NAME.getValue(), toolResponse.name(), + ToolResponseAttributes.RESPONSE_DATA.getValue(), toolResponse.responseData(), + ToolResponseAttributes.IDX.getValue(), Integer.toString(i)); + toolResponseMaps.add(toolResponseMap); + } + statementBuilder.append(""" + WITH msg + FOREACH(tr IN $toolResponses | CREATE (tm:%s) + SET tm = tr + MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm)) + """.formatted(this.config.getToolResponseLabel())); + queryParameters.put("toolResponses", toolResponseMaps); + } + if (message instanceof MediaContent messageWithMedia && !messageWithMedia.getMedia().isEmpty()) { + List> mediaNodes = convertMediaToMap(messageWithMedia.getMedia()); + statementBuilder.append(""" + WITH msg + UNWIND $media AS m + CREATE (media:%s) SET media = m + WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media) + """.formatted(this.config.getMediaLabel())); + queryParameters.put("media", mediaNodes); + } + t.run(statementBuilder.toString(), queryParameters); + } + + private List> convertMediaToMap(List media) { + List> mediaMaps = new ArrayList<>(); + for (int i = 0; i < media.size(); i++) { + Map mediaMap = new HashMap<>(); + Media m = media.get(i); + mediaMap.put(MediaAttributes.ID.getValue(), m.getId()); + mediaMap.put(MediaAttributes.MIME_TYPE.getValue(), m.getMimeType().toString()); + mediaMap.put(MediaAttributes.NAME.getValue(), m.getName()); + mediaMap.put(MediaAttributes.DATA.getValue(), m.getData()); + mediaMap.put(MediaAttributes.IDX.getValue(), i); + mediaMaps.add(mediaMap); + } + return mediaMaps; + } + +} diff --git a/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfigIT.java b/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfigIT.java new file mode 100644 index 000000000..faf6a6130 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfigIT.java @@ -0,0 +1,91 @@ +package org.springframework.ai.chat.memory.neo4j; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.AfterAll; +import org.neo4j.driver.Driver; +import org.neo4j.driver.GraphDatabase; +import org.neo4j.driver.Session; +import org.neo4j.driver.Result; +import org.testcontainers.containers.Neo4jContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; + +@Testcontainers +class Neo4jChatMemoryConfigIT { + + @Container + static final Neo4jContainer neo4jContainer = new Neo4jContainer<>("neo4j:5").withoutAuthentication(); + + static Driver driver; + + @BeforeAll + static void setupDriver() { + driver = GraphDatabase.driver(neo4jContainer.getBoltUrl()); + } + + @AfterAll + static void closeDriver() { + if (driver != null) + driver.close(); + } + + @Test + void shouldCreateRequiredIndexes() { + // Given + Neo4jChatMemoryConfig config = Neo4jChatMemoryConfig.builder().withDriver(driver).build(); + // When + try (Session session = driver.session()) { + Result result = session.run("SHOW INDEXES"); + boolean sessionIndexFound = false; + boolean messageIndexFound = false; + while (result.hasNext()) { + var record = result.next(); + String name = record.get("name").asString(); + if ("session_conversation_id_index".equals(name)) + sessionIndexFound = true; + if ("message_index_index".equals(name)) + messageIndexFound = true; + } + // Then + assertThat(sessionIndexFound).isTrue(); + assertThat(messageIndexFound).isTrue(); + } + } + + @Test + void builderShouldSetCustomLabels() { + String customSessionLabel = "ChatSession"; + String customMessageLabel = "ChatMessage"; + Neo4jChatMemoryConfig config = Neo4jChatMemoryConfig.builder() + .withDriver(driver) + .withSessionLabel(customSessionLabel) + .withMessageLabel(customMessageLabel) + .build(); + assertThat(config.getSessionLabel()).isEqualTo(customSessionLabel); + assertThat(config.getMessageLabel()).isEqualTo(customMessageLabel); + } + + @Test + void gettersShouldReturnConfiguredValues() { + Neo4jChatMemoryConfig config = Neo4jChatMemoryConfig.builder() + .withDriver(driver) + .withSessionLabel("Session") + .withToolCallLabel("ToolCall") + .withMetadataLabel("Metadata") + .withMessageLabel("Message") + .withToolResponseLabel("ToolResponse") + .withMediaLabel("Media") + .build(); + assertThat(config.getSessionLabel()).isEqualTo("Session"); + assertThat(config.getToolCallLabel()).isEqualTo("ToolCall"); + assertThat(config.getMetadataLabel()).isEqualTo("Metadata"); + assertThat(config.getMessageLabel()).isEqualTo("Message"); + assertThat(config.getToolResponseLabel()).isEqualTo("ToolResponse"); + assertThat(config.getMediaLabel()).isEqualTo("Media"); + assertThat(config.getDriver()).isNotNull(); + } + +} diff --git a/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepositoryIT.java b/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepositoryIT.java new file mode 100644 index 000000000..e01b9caeb --- /dev/null +++ b/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepositoryIT.java @@ -0,0 +1,420 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.neo4j; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Result; +import org.neo4j.driver.Session; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.content.Media; +import org.springframework.util.MimeType; +import org.testcontainers.containers.Neo4jContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link Neo4jChatMemoryRepository}. + * + * @author Enrico Rampazzo + * @since 1.0.0 + */ +@Testcontainers +class Neo4jChatMemoryRepositoryIT { + + static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j"); + + @SuppressWarnings({ "rawtypes", "resource" }) + @Container + static Neo4jContainer neo4jContainer = (Neo4jContainer) new Neo4jContainer(DEFAULT_IMAGE_NAME.withTag("5")) + .withoutAuthentication() + .withExposedPorts(7474, 7687); + + private ChatMemoryRepository chatMemoryRepository; + + private Driver driver; + + private Neo4jChatMemoryConfig config; + + @BeforeEach + void setUp() { + driver = Neo4jDriverFactory.create(neo4jContainer.getBoltUrl()); + config = Neo4jChatMemoryConfig.builder().withDriver(driver).build(); + chatMemoryRepository = new Neo4jChatMemoryRepository(config); + } + + @AfterEach + void tearDown() { + // Clean up all data after each test + try (Session session = driver.session()) { + session.run("MATCH (n) DETACH DELETE n"); + } + driver.close(); + } + + @Test + void correctChatMemoryRepositoryInstance() { + assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class); + assertThat(chatMemoryRepository).isInstanceOf(Neo4jChatMemoryRepository.class); + } + + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM", + "Message from tool,TOOL" }) + void saveAndFindSingleMessage(String content, MessageType messageType) { + var conversationId = UUID.randomUUID().toString(); + Message message = createMessageByType(content + " - " + conversationId, messageType); + + chatMemoryRepository.saveAll(conversationId, List.of(message)); + List retrievedMessages = chatMemoryRepository.findByConversationId(conversationId); + + assertThat(retrievedMessages).hasSize(1); + + Message retrievedMessage = retrievedMessages.get(0); + assertThat(retrievedMessage.getMessageType()).isEqualTo(messageType); + + if (messageType != MessageType.TOOL) { + assertThat(retrievedMessage.getText()).isEqualTo(message.getText()); + } + + // Verify directly in the database + try (Session session = driver.session()) { + var result = session.run( + "MATCH (s:%s {id:$conversationId})-[:HAS_MESSAGE]->(m:%s) RETURN count(m) as count" + .formatted(config.getSessionLabel(), config.getMessageLabel()), + Map.of("conversationId", conversationId)); + assertThat(result.single().get("count").asLong()).isEqualTo(1); + } + } + + @Test + void saveAndFindMultipleMessages() { + var conversationId = UUID.randomUUID().toString(); + List messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId), + new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData")))); + + chatMemoryRepository.saveAll(conversationId, messages); + List retrievedMessages = chatMemoryRepository.findByConversationId(conversationId); + + assertThat(retrievedMessages).hasSize(messages.size()); + + // Verify the order is preserved (ascending by index) + for (int i = 0; i < messages.size(); i++) { + if (messages.get(i).getMessageType() != MessageType.TOOL) { + assertThat(retrievedMessages.get(i).getText()).isEqualTo(messages.get(i).getText()); + } + assertThat(retrievedMessages.get(i).getMessageType()).isEqualTo(messages.get(i).getMessageType()); + } + } + + @Test + void verifyMessageOrdering() { + var conversationId = UUID.randomUUID().toString(); + List messages = new ArrayList<>(); + + // Add messages in a specific order + for (int i = 1; i <= 5; i++) { + messages.add(new UserMessage("Message " + i)); + } + + chatMemoryRepository.saveAll(conversationId, messages); + List retrievedMessages = chatMemoryRepository.findByConversationId(conversationId); + + assertThat(retrievedMessages).hasSize(messages.size()); + + // Verify that messages are returned in ascending order (oldest first) + for (int i = 0; i < messages.size(); i++) { + assertThat(retrievedMessages.get(i).getText()).isEqualTo("Message " + (i + 1)); + } + } + + @Test + void findConversationIds() { + // Create multiple conversations + var conversationId1 = UUID.randomUUID().toString(); + var conversationId2 = UUID.randomUUID().toString(); + var conversationId3 = UUID.randomUUID().toString(); + + chatMemoryRepository.saveAll(conversationId1, List.of(new UserMessage("Message for conversation 1"))); + chatMemoryRepository.saveAll(conversationId2, List.of(new UserMessage("Message for conversation 2"))); + chatMemoryRepository.saveAll(conversationId3, List.of(new UserMessage("Message for conversation 3"))); + + List conversationIds = chatMemoryRepository.findConversationIds(); + + assertThat(conversationIds).hasSize(3); + assertThat(conversationIds).contains(conversationId1, conversationId2, conversationId3); + } + + @Test + void deleteByConversationId() { + var conversationId = UUID.randomUUID().toString(); + List messages = List.of(new AssistantMessage("Message from assistant"), + new UserMessage("Message from user"), new SystemMessage("Message from system")); + + chatMemoryRepository.saveAll(conversationId, messages); + + // Verify messages were saved + assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3); + + // Delete the conversation + chatMemoryRepository.deleteByConversationId(conversationId); + + // Verify messages were deleted + assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEmpty(); + + // Verify directly in the database + try (Session session = driver.session()) { + var result = session.run( + "MATCH (s:%s {id:$conversationId}) RETURN count(s) as count".formatted(config.getSessionLabel()), + Map.of("conversationId", conversationId)); + assertThat(result.single().get("count").asLong()).isZero(); + } + } + + @Test + void saveAllReplacesExistingMessages() { + var conversationId = UUID.randomUUID().toString(); + + // Save initial messages + List initialMessages = List.of(new UserMessage("Initial message 1"), + new UserMessage("Initial message 2"), new UserMessage("Initial message 3")); + chatMemoryRepository.saveAll(conversationId, initialMessages); + + // Verify initial messages were saved + assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3); + + // Replace with new messages + List newMessages = List.of(new UserMessage("New message 1"), new UserMessage("New message 2")); + chatMemoryRepository.saveAll(conversationId, newMessages); + + // Verify only new messages exist + List retrievedMessages = chatMemoryRepository.findByConversationId(conversationId); + assertThat(retrievedMessages).hasSize(2); + assertThat(retrievedMessages.get(0).getText()).isEqualTo("New message 1"); + assertThat(retrievedMessages.get(1).getText()).isEqualTo("New message 2"); + } + + @Test + void handleMediaContent() { + var conversationId = UUID.randomUUID().toString(); + + MimeType textPlain = MimeType.valueOf("text/plain"); + List media = List.of(Media.builder() + .name("some media") + .id(UUID.randomUUID().toString()) + .mimeType(textPlain) + .data("hello".getBytes(StandardCharsets.UTF_8)) + .build(), Media.builder().data(URI.create("http://www.example.com")).mimeType(textPlain).build()); + + UserMessage userMessageWithMedia = UserMessage.builder().text("Message with media").media(media).build(); + + chatMemoryRepository.saveAll(conversationId, List.of(userMessageWithMedia)); + + List retrievedMessages = chatMemoryRepository.findByConversationId(conversationId); + assertThat(retrievedMessages).hasSize(1); + + UserMessage retrievedMessage = (UserMessage) retrievedMessages.get(0); + assertThat(retrievedMessage.getMedia()).hasSize(2); + assertThat(retrievedMessage.getMedia()).usingRecursiveFieldByFieldElementComparator().isEqualTo(media); + } + + @Test + void handleAssistantMessageWithToolCalls() { + var conversationId = UUID.randomUUID().toString(); + + AssistantMessage assistantMessage = new AssistantMessage("Message with tool calls", Map.of(), + List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"), + new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2"))); + + chatMemoryRepository.saveAll(conversationId, List.of(assistantMessage)); + + List retrievedMessages = chatMemoryRepository.findByConversationId(conversationId); + assertThat(retrievedMessages).hasSize(1); + + AssistantMessage retrievedMessage = (AssistantMessage) retrievedMessages.get(0); + assertThat(retrievedMessage.getToolCalls()).hasSize(2); + assertThat(retrievedMessage.getToolCalls().get(0).id()).isEqualTo("id1"); + assertThat(retrievedMessage.getToolCalls().get(1).id()).isEqualTo("id2"); + } + + @Test + void handleToolResponseMessage() { + var conversationId = UUID.randomUUID().toString(); + + ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List + .of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")), + Map.of("metadataKey", "metadataValue")); + + chatMemoryRepository.saveAll(conversationId, List.of(toolResponseMessage)); + + List retrievedMessages = chatMemoryRepository.findByConversationId(conversationId); + assertThat(retrievedMessages).hasSize(1); + + ToolResponseMessage retrievedMessage = (ToolResponseMessage) retrievedMessages.get(0); + assertThat(retrievedMessage.getResponses()).hasSize(2); + assertThat(retrievedMessage.getResponses().get(0).id()).isEqualTo("id1"); + assertThat(retrievedMessage.getResponses().get(1).id()).isEqualTo("id2"); + assertThat(retrievedMessage.getMetadata()).containsEntry("metadataKey", "metadataValue"); + } + + @Test + void saveAndFindSystemMessageWithMetadata() { + var conversationId = UUID.randomUUID().toString(); + Map customMetadata = Map.of("priority", "high", "source", "test"); + + SystemMessage systemMessage = SystemMessage.builder() + .text("System message with custom metadata - " + conversationId) + .metadata(customMetadata) + .build(); + + chatMemoryRepository.saveAll(conversationId, List.of(systemMessage)); + List retrievedMessages = chatMemoryRepository.findByConversationId(conversationId); + + assertThat(retrievedMessages).hasSize(1); + Message retrievedMessage = retrievedMessages.get(0); + + assertThat(retrievedMessage).isInstanceOf(SystemMessage.class); + assertThat(retrievedMessage.getText()).isEqualTo("System message with custom metadata - " + conversationId); + // Crucial assertion for the metadata + assertThat(retrievedMessage.getMetadata()).containsAllEntriesOf(customMetadata); + // Also check that the 'messageType' key is present (added by the repository) + assertThat(retrievedMessage.getMetadata()).containsEntry("messageType", MessageType.SYSTEM); + // Verify no extra unwanted metadata keys beyond what's expected + assertThat(retrievedMessage.getMetadata().keySet()) + .containsExactlyInAnyOrderElementsOf(new ArrayList<>(customMetadata.keySet()) { + { + add("messageType"); + } + }); + } + + @Test + void saveAllWithEmptyListClearsConversation() { + var conversationId = UUID.randomUUID().toString(); + + // 1. Setup: Create a conversation with some initial messages + UserMessage initialMessage1 = new UserMessage("Initial message 1"); + AssistantMessage initialMessage2 = new AssistantMessage("Initial response 1"); + chatMemoryRepository.saveAll(conversationId, List.of(initialMessage1, initialMessage2)); + + // Verify initial messages are there + List messagesAfterInitialSave = chatMemoryRepository.findByConversationId(conversationId); + assertThat(messagesAfterInitialSave).hasSize(2); + + // 2. Action: Call saveAll with an empty list + chatMemoryRepository.saveAll(conversationId, Collections.emptyList()); + + // 3. Assertions: + // a) No messages should be found for the conversationId + List messagesAfterEmptySave = chatMemoryRepository.findByConversationId(conversationId); + assertThat(messagesAfterEmptySave).isEmpty(); + + // b) The conversationId itself should no longer be listed (because + // deleteByConversationId removes the session node) + List conversationIds = chatMemoryRepository.findConversationIds(); + assertThat(conversationIds).doesNotContain(conversationId); + + // c) Verify directly in Neo4j that the conversation node is gone + try (Session session = driver.session()) { + Result result = session.run( + "MATCH (s:%s {id: $conversationId}) RETURN s".formatted(config.getSessionLabel()), + Map.of("conversationId", conversationId)); + assertThat(result.hasNext()).isFalse(); // No conversation node should exist + } + } + + @Test + void saveAndFindMessagesWithEmptyContentOrMetadata() { + var conversationId = UUID.randomUUID().toString(); + + UserMessage messageWithEmptyContent = new UserMessage(""); + UserMessage messageWithEmptyMetadata = UserMessage.builder() + .text("Content with empty metadata") + .metadata(Collections.emptyMap()) + .build(); + + List messagesToSave = List.of(messageWithEmptyContent, messageWithEmptyMetadata); + chatMemoryRepository.saveAll(conversationId, messagesToSave); + + List retrievedMessages = chatMemoryRepository.findByConversationId(conversationId); + assertThat(retrievedMessages).hasSize(2); + + // Verify first message (empty content) + Message retrievedEmptyContentMsg = retrievedMessages.get(0); + assertThat(retrievedEmptyContentMsg).isInstanceOf(UserMessage.class); + assertThat(retrievedEmptyContentMsg.getText()).isEqualTo(""); + assertThat(retrievedEmptyContentMsg.getMetadata()).containsEntry("messageType", MessageType.USER); // Default + // metadata + assertThat(retrievedEmptyContentMsg.getMetadata().keySet()).hasSize(1); // Only + // messageType + + // Verify second message (empty metadata from input, should only have messageType + // after retrieval) + Message retrievedEmptyMetadataMsg = retrievedMessages.get(1); + assertThat(retrievedEmptyMetadataMsg).isInstanceOf(UserMessage.class); + assertThat(retrievedEmptyMetadataMsg.getText()).isEqualTo("Content with empty metadata"); + assertThat(retrievedEmptyMetadataMsg.getMetadata()).containsEntry("messageType", MessageType.USER); + assertThat(retrievedEmptyMetadataMsg.getMetadata().keySet()).hasSize(1); // Only + // messageType + } + + private Message createMessageByType(String content, MessageType messageType) { + return switch (messageType) { + case ASSISTANT -> new AssistantMessage(content); + case USER -> new UserMessage(content); + case SYSTEM -> new SystemMessage(content); + case TOOL -> new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData"))); + }; + } + + /** + * Factory for creating Neo4j Driver instances. + */ + private static class Neo4jDriverFactory { + + static Driver create(String boltUrl) { + return org.neo4j.driver.GraphDatabase.driver(boltUrl); + } + + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc index 819de5cb2..8b1427f3f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc @@ -64,7 +64,7 @@ If you'd rather create the `InMemoryChatMemoryRepository` manually, you can do s ChatMemoryRepository repository = new InMemoryChatMemoryRepository(); ---- -=== JDBC Repository +=== JdbcChatMemoryRepository `JdbcChatMemoryRepository` is a built-in implementation that uses JDBC to store messages in a relational database. It is suitable for applications that require persistent storage of chat memory. @@ -127,6 +127,7 @@ ChatMemory chatMemory = MessageWindowChatMemory.builder() | `spring.ai.chat.memory.repository.jdbc.initialize-schema` | Whether to initialize the schema on startup. | `true` |=== + ==== Schema Initialization The auto-configuration will automatically create the `ai_chat_memory` table using the JDBC driver. Currently, only PostgreSQL and MariaDB are supported. @@ -135,6 +136,78 @@ You can disable the schema initialization by setting the property `spring.ai.cha If your project uses a tool like Flyway or Liquibase to manage your database schemas, you can disable the schema initialization and refer to link:https://github.com/spring-projects/spring-ai/tree/main/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc[these SQL scripts] for configuring those tools to create the `ai_chat_memory` table. +=== Neo4j ChatMemoryRepository + +`Neo4jChatMemoryRepository` is a built-in implementation that uses Neo4j to store chat messages as nodes and relationships in a property graph database. It is suitable for applications that want to leverage Neo4j's graph capabilities for chat memory persistence. + +First, add the following dependency to your project: + +[tabs] +====== +Maven:: ++ +[source, xml] +---- + + org.springframework.ai + spring-ai-starter-model-chat-memory-neo4j + +---- + +Gradle:: ++ +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-starter-model-chat-memory-neo4j' +} +---- +====== + +Spring AI provides auto-configuration for the `Neo4jChatMemoryRepository`, which you can use directly in your application. + +[source,java] +---- +@Autowired +Neo4jChatMemoryRepository chatMemoryRepository; + +ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(chatMemoryRepository) + .maxMessages(10) + .build(); +---- + +If you'd rather create the `Neo4jChatMemoryRepository` manually, you can do so by providing a Neo4j `Driver` instance: + +[source,java] +---- +ChatMemoryRepository chatMemoryRepository = Neo4jChatMemoryRepository.builder() + .driver(driver) + .build(); + +ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(chatMemoryRepository) + .maxMessages(10) + .build(); +---- + +==== Configuration Properties + +[cols="2,5,1",stripes=even] +|=== +|Property | Description | Default Value +| `spring.ai.chat.memory.neo4j.sessionLabel` | The label for the nodes that store conversation sessions | `Session` +| `spring.ai.chat.memory.neo4j.messageLabel` | The label for the nodes that store messages | `Message` +| `spring.ai.chat.memory.neo4j.toolCallLabel` | The label for nodes that store tool calls (e.g. in Assistant Messages) | `ToolCall` +| `spring.ai.chat.memory.neo4j.metadataLabel` | The label for nodes that store message metadata | `Metadata` +| `spring.ai.chat.memory.neo4j.toolResponseLabel` | The label for the nodes that store tool responses | `ToolResponse` +| `spring.ai.chat.memory.neo4j.mediaLabel` | The label for the nodes that store media associated with a message | `Media` +|=== + +==== Index Initialization + +The Neo4j repository will automatically ensure that indexes are created for conversation IDs and message indices to optimize performance. If you use custom labels, indexes will be created for those labels as well. No schema initialization is required, but you should ensure your Neo4j instance is accessible to your application. + == Memory in Chat Client When using the ChatClient API, you can provide a `ChatMemory` implementation to maintain conversation context across multiple interactions.