From 31da8f3dbb63909bab52c9855ea854a553a3349d Mon Sep 17 00:00:00 2001 From: Sun Yuhan <1085481446@qq.com> Date: Wed, 14 May 2025 16:34:33 +0800 Subject: [PATCH] fix: Fixed the incorrect SQL in getSelectMessagesSql of MysqlChatMemoryRepositoryDialect - Added tests Signed-off-by: Sun Yuhan <1085481446@qq.com> --- .../pom.xml | 18 ++ .../MysqlChatMemoryRepositoryDialect.java | 4 +- .../SqlServerChatMemoryRepositoryDialect.java | 2 +- .../AbstractJdbcChatMemoryRepositoryIT.java | 207 ++++++++++++++++++ .../jdbc/JdbcChatMemoryRepositoryMysqlIT.java | 41 ++++ .../JdbcChatMemoryRepositoryPostgresqlIT.java | 126 +---------- 6 files changed, 273 insertions(+), 125 deletions(-) create mode 100644 memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/AbstractJdbcChatMemoryRepositoryIT.java create mode 100644 memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryMysqlIT.java diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/pom.xml b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/pom.xml index 4f8364433..636473006 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/pom.xml +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/pom.xml @@ -69,6 +69,13 @@ true + + com.microsoft.sqlserver + mssql-jdbc + test + true + + org.springframework.boot @@ -94,10 +101,21 @@ test + + org.testcontainers + mssqlserver + test + + org.testcontainers junit-jupiter test + + org.testcontainers + mssqlserver + test + diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/MysqlChatMemoryRepositoryDialect.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/MysqlChatMemoryRepositoryDialect.java index a41fd76ee..045bb1f5e 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/MysqlChatMemoryRepositoryDialect.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/MysqlChatMemoryRepositoryDialect.java @@ -17,7 +17,7 @@ package org.springframework.ai.chat.memory.repository.jdbc; /** - * Dialect for MySQL. + * MySQL dialect for chat memory repository. * * @author Mark Pollack * @since 1.0.0 @@ -26,7 +26,7 @@ public class MysqlChatMemoryRepositoryDialect implements JdbcChatMemoryRepositor @Override public String getSelectMessagesSql() { - return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp` DESC LIMIT ?"; + return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp`"; } @Override diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/SqlServerChatMemoryRepositoryDialect.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/SqlServerChatMemoryRepositoryDialect.java index c4555de08..dcc477fb0 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/SqlServerChatMemoryRepositoryDialect.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/SqlServerChatMemoryRepositoryDialect.java @@ -26,7 +26,7 @@ public class SqlServerChatMemoryRepositoryDialect implements JdbcChatMemoryRepos @Override public String getSelectMessagesSql() { - return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp] ASC"; + return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp]"; } @Override diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/AbstractJdbcChatMemoryRepositoryIT.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/AbstractJdbcChatMemoryRepositoryIT.java new file mode 100644 index 000000000..0ec671d0e --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/AbstractJdbcChatMemoryRepositoryIT.java @@ -0,0 +1,207 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.repository.jdbc; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.jdbc.core.JdbcTemplate; + +import java.sql.Timestamp; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; + +import javax.sql.DataSource; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Base class for integration tests for {@link JdbcChatMemoryRepository}. + * + * @author Mark Pollack + */ +public abstract class AbstractJdbcChatMemoryRepositoryIT { + + @Autowired + protected ChatMemoryRepository chatMemoryRepository; + + @Autowired + protected JdbcTemplate jdbcTemplate; + + @Test + void correctChatMemoryRepositoryInstance() { + assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class); + } + + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" }) + void saveMessagesSingleMessage(String content, MessageType messageType) { + String conversationId = UUID.randomUUID().toString(); + var message = switch (messageType) { + case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); + case USER -> new UserMessage(content + " - " + conversationId); + case SYSTEM -> new SystemMessage(content + " - " + conversationId); + case TOOL -> throw new IllegalArgumentException("TOOL message type not supported in this test"); + }; + + chatMemoryRepository.saveAll(conversationId, List.of(message)); + + // Use dialect to get the appropriate SQL query + JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource()); + String selectSql = dialect.getSelectMessagesSql() + .replace("content, type", "conversation_id, content, type, timestamp"); + var result = jdbcTemplate.queryForMap(selectSql, conversationId); + + assertThat(result.size()).isEqualTo(4); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(messageType.name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + } + + @Test + void saveMessagesMultipleMessages() { + String conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemoryRepository.saveAll(conversationId, messages); + + // Use dialect to get the appropriate SQL query + JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource()); + String selectSql = dialect.getSelectMessagesSql() + .replace("content, type", "conversation_id, content, type, timestamp"); + var results = jdbcTemplate.queryForList(selectSql, conversationId); + + assertThat(results).hasSize(messages.size()); + + for (int i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(message.getMessageType().name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + } + + var count = chatMemoryRepository.findByConversationId(conversationId).size(); + assertThat(count).isEqualTo(messages.size()); + + chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello"))); + + count = chatMemoryRepository.findByConversationId(conversationId).size(); + assertThat(count).isEqualTo(1); + } + + @Test + void findMessagesByConversationId() { + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant 1 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemoryRepository.saveAll(conversationId, messages); + + var results = chatMemoryRepository.findByConversationId(conversationId); + + assertThat(results.size()).isEqualTo(messages.size()); + assertThat(results).isEqualTo(messages); + } + + @Test + void deleteMessagesByConversationId() { + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemoryRepository.saveAll(conversationId, messages); + + chatMemoryRepository.deleteByConversationId(conversationId); + + var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?", + Integer.class, conversationId); + + assertThat(count).isZero(); + } + + @Test + void testMessageOrder() { + // Create a repository using the from method to detect the dialect + JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder() + .jdbcTemplate(jdbcTemplate) + .dialect(JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource())) + .build(); + + var conversationId = UUID.randomUUID().toString(); + + // Create messages with very distinct content to make order obvious + var firstMessage = new UserMessage("1-First message"); + var secondMessage = new AssistantMessage("2-Second message"); + var thirdMessage = new UserMessage("3-Third message"); + var fourthMessage = new SystemMessage("4-Fourth message"); + + // Save messages in the expected order + List orderedMessages = List.of(firstMessage, secondMessage, thirdMessage, fourthMessage); + repository.saveAll(conversationId, orderedMessages); + + // Retrieve messages using the repository + List retrievedMessages = repository.findByConversationId(conversationId); + assertThat(retrievedMessages).hasSize(4); + + // Get the actual order from the retrieved messages + List retrievedContents = retrievedMessages.stream().map(Message::getText).collect(Collectors.toList()); + + // Messages should be in the original order (ASC) + assertThat(retrievedContents).containsExactly("1-First message", "2-Second message", "3-Third message", + "4-Fourth message"); + } + + /** + * Base configuration for all integration tests. + */ + @ImportAutoConfiguration({ DataSourceAutoConfiguration.class, JdbcTemplateAutoConfiguration.class }) + static abstract class BaseTestConfiguration { + + @Bean + ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate, DataSource dataSource) { + return JdbcChatMemoryRepository.builder() + .jdbcTemplate(jdbcTemplate) + .dialect(JdbcChatMemoryRepositoryDialect.from(dataSource)) + .build(); + } + + } + +} diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryMysqlIT.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryMysqlIT.java new file mode 100644 index 000000000..22b3b45c3 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryMysqlIT.java @@ -0,0 +1,41 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.repository.jdbc; + +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.TestPropertySource; +import org.springframework.test.context.jdbc.Sql; + +/** + * Integration tests for {@link JdbcChatMemoryRepository} with MySQL. + * + * @author Jonathan Leijendekker + * @author Thomas Vitale + * @author Mark Pollack + */ +@SpringBootTest(classes = JdbcChatMemoryRepositoryMysqlIT.TestConfiguration.class) +@TestPropertySource(properties = { "spring.datasource.url=jdbc:tc:mariadb:10.3.39:///" }) +@Sql(scripts = "classpath:org/springframework/ai/chat/memory/repository/jdbc/schema-mariadb.sql") +class JdbcChatMemoryRepositoryMysqlIT extends AbstractJdbcChatMemoryRepositoryIT { + + @SpringBootConfiguration + static class TestConfiguration extends BaseTestConfiguration { + + } + +} diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java index 59a1148f1..026ca4ed5 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java @@ -16,147 +16,38 @@ package org.springframework.ai.chat.memory.repository.jdbc; -import java.sql.Timestamp; import java.util.List; import java.util.UUID; import javax.sql.DataSource; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.CsvSource; import org.springframework.ai.chat.memory.ChatMemoryRepository; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; -import org.springframework.boot.autoconfigure.ImportAutoConfiguration; -import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; -import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.datasource.DataSourceTransactionManager; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.jdbc.Sql; -import org.springframework.transaction.support.TransactionTemplate; import static org.assertj.core.api.Assertions.assertThat; /** - * Integration tests for {@link JdbcChatMemoryRepository}. + * Integration tests for {@link JdbcChatMemoryRepository} with PostgreSQL. * * @author Jonathan Leijendekker * @author Thomas Vitale + * @author Mark Pollack */ @SpringBootTest(classes = JdbcChatMemoryRepositoryPostgresqlIT.TestConfiguration.class) @TestPropertySource(properties = "spring.datasource.url=jdbc:tc:postgresql:17:///") @Sql(scripts = "classpath:org/springframework/ai/chat/memory/repository/jdbc/schema-postgresql.sql") -class JdbcChatMemoryRepositoryPostgresqlIT { - - @Autowired - private ChatMemoryRepository chatMemoryRepository; - - @Autowired - private JdbcTemplate jdbcTemplate; - - @Test - void correctChatMemoryRepositoryInstance() { - assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class); - } - - @ParameterizedTest - @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" }) - void saveMessagesSingleMessage(String content, MessageType messageType) { - var conversationId = UUID.randomUUID().toString(); - var message = switch (messageType) { - case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); - case USER -> new UserMessage(content + " - " + conversationId); - case SYSTEM -> new SystemMessage(content + " - " + conversationId); - default -> throw new IllegalArgumentException("Type not supported: " + messageType); - }; - - chatMemoryRepository.saveAll(conversationId, List.of(message)); - - var query = "SELECT conversation_id, content, type, \"timestamp\" FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?"; - var result = jdbcTemplate.queryForMap(query, conversationId); - - assertThat(result.size()).isEqualTo(4); - assertThat(result.get("conversation_id")).isEqualTo(conversationId); - assertThat(result.get("content")).isEqualTo(message.getText()); - assertThat(result.get("type")).isEqualTo(messageType.name()); - assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); - } - - @Test - void saveMessagesMultipleMessages() { - var conversationId = UUID.randomUUID().toString(); - var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), - new UserMessage("Message from user - " + conversationId), - new SystemMessage("Message from system - " + conversationId)); - - chatMemoryRepository.saveAll(conversationId, messages); - - var query = "SELECT conversation_id, content, type, \"timestamp\" FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?"; - var results = jdbcTemplate.queryForList(query, conversationId); - - assertThat(results.size()).isEqualTo(messages.size()); - - for (var i = 0; i < messages.size(); i++) { - var message = messages.get(i); - var result = results.get(i); - - assertThat(result.get("conversation_id")).isNotNull(); - assertThat(result.get("conversation_id")).isEqualTo(conversationId); - assertThat(result.get("content")).isEqualTo(message.getText()); - assertThat(result.get("type")).isEqualTo(message.getMessageType().name()); - assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); - } - - var count = chatMemoryRepository.findByConversationId(conversationId).size(); - assertThat(count).isEqualTo(messages.size()); - - chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello"))); - - count = chatMemoryRepository.findByConversationId(conversationId).size(); - assertThat(count).isEqualTo(1); - } - - @Test - void findMessagesByConversationId() { - var conversationId = UUID.randomUUID().toString(); - var messages = List.of(new AssistantMessage("Message from assistant 1 - " + conversationId), - new AssistantMessage("Message from assistant 2 - " + conversationId), - new UserMessage("Message from user - " + conversationId), - new SystemMessage("Message from system - " + conversationId)); - - chatMemoryRepository.saveAll(conversationId, messages); - - var results = chatMemoryRepository.findByConversationId(conversationId); - - assertThat(results.size()).isEqualTo(messages.size()); - assertThat(results).isEqualTo(messages); - } - - @Test - void deleteMessagesByConversationId() { - var conversationId = UUID.randomUUID().toString(); - var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), - new UserMessage("Message from user - " + conversationId), - new SystemMessage("Message from system - " + conversationId)); - - chatMemoryRepository.saveAll(conversationId, messages); - - chatMemoryRepository.deleteByConversationId(conversationId); - - var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?", - Integer.class, conversationId); - - assertThat(count).isZero(); - } +class JdbcChatMemoryRepositoryPostgresqlIT extends AbstractJdbcChatMemoryRepositoryIT { @Test void repositoryWithExplicitTransactionManager() { @@ -187,16 +78,7 @@ class JdbcChatMemoryRepositoryPostgresqlIT { } @SpringBootConfiguration - @ImportAutoConfiguration({ DataSourceAutoConfiguration.class, JdbcTemplateAutoConfiguration.class }) - static class TestConfiguration { - - @Bean - ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate, DataSource dataSource) { - return JdbcChatMemoryRepository.builder() - .jdbcTemplate(jdbcTemplate) - .dialect(JdbcChatMemoryRepositoryDialect.from(dataSource)) - .build(); - } + static class TestConfiguration extends BaseTestConfiguration { @Bean ChatMemoryRepository chatMemoryRepositoryWithTxManager(JdbcTemplate jdbcTemplate, DataSource dataSource) {