fix: Fixed the incorrect SQL in getSelectMessagesSql of MysqlChatMemoryRepositoryDialect

- Added tests

Signed-off-by: Sun Yuhan <1085481446@qq.com>
This commit is contained in:
Sun Yuhan
2025-05-14 16:34:33 +08:00
committed by Mark Pollack
parent 0e1bb52ab7
commit 31da8f3dbb
6 changed files with 273 additions and 125 deletions

View File

@@ -69,6 +69,13 @@
<optional>true</optional>
</dependency>
<dependency>
<groupId>com.microsoft.sqlserver</groupId>
<artifactId>mssql-jdbc</artifactId>
<scope>test</scope>
<optional>true</optional>
</dependency>
<!-- TESTING -->
<dependency>
<groupId>org.springframework.boot</groupId>
@@ -94,10 +101,21 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>mssqlserver</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>mssqlserver</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@@ -17,7 +17,7 @@
package org.springframework.ai.chat.memory.repository.jdbc;
/**
* Dialect for MySQL.
* MySQL dialect for chat memory repository.
*
* @author Mark Pollack
* @since 1.0.0
@@ -26,7 +26,7 @@ public class MysqlChatMemoryRepositoryDialect implements JdbcChatMemoryRepositor
@Override
public String getSelectMessagesSql() {
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp` DESC LIMIT ?";
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp`";
}
@Override

View File

@@ -26,7 +26,7 @@ public class SqlServerChatMemoryRepositoryDialect implements JdbcChatMemoryRepos
@Override
public String getSelectMessagesSql() {
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp] ASC";
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp]";
}
@Override

View File

@@ -0,0 +1,207 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.memory.repository.jdbc;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.jdbc.core.JdbcTemplate;
import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Base class for integration tests for {@link JdbcChatMemoryRepository}.
*
* @author Mark Pollack
*/
public abstract class AbstractJdbcChatMemoryRepositoryIT {
@Autowired
protected ChatMemoryRepository chatMemoryRepository;
@Autowired
protected JdbcTemplate jdbcTemplate;
@Test
void correctChatMemoryRepositoryInstance() {
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
}
@ParameterizedTest
@CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" })
void saveMessagesSingleMessage(String content, MessageType messageType) {
String conversationId = UUID.randomUUID().toString();
var message = switch (messageType) {
case ASSISTANT -> new AssistantMessage(content + " - " + conversationId);
case USER -> new UserMessage(content + " - " + conversationId);
case SYSTEM -> new SystemMessage(content + " - " + conversationId);
case TOOL -> throw new IllegalArgumentException("TOOL message type not supported in this test");
};
chatMemoryRepository.saveAll(conversationId, List.of(message));
// Use dialect to get the appropriate SQL query
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
String selectSql = dialect.getSelectMessagesSql()
.replace("content, type", "conversation_id, content, type, timestamp");
var result = jdbcTemplate.queryForMap(selectSql, conversationId);
assertThat(result.size()).isEqualTo(4);
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
assertThat(result.get("content")).isEqualTo(message.getText());
assertThat(result.get("type")).isEqualTo(messageType.name());
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
}
@Test
void saveMessagesMultipleMessages() {
String conversationId = UUID.randomUUID().toString();
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId),
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
chatMemoryRepository.saveAll(conversationId, messages);
// Use dialect to get the appropriate SQL query
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
String selectSql = dialect.getSelectMessagesSql()
.replace("content, type", "conversation_id, content, type, timestamp");
var results = jdbcTemplate.queryForList(selectSql, conversationId);
assertThat(results).hasSize(messages.size());
for (int i = 0; i < messages.size(); i++) {
var message = messages.get(i);
var result = results.get(i);
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
assertThat(result.get("content")).isEqualTo(message.getText());
assertThat(result.get("type")).isEqualTo(message.getMessageType().name());
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
}
var count = chatMemoryRepository.findByConversationId(conversationId).size();
assertThat(count).isEqualTo(messages.size());
chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello")));
count = chatMemoryRepository.findByConversationId(conversationId).size();
assertThat(count).isEqualTo(1);
}
@Test
void findMessagesByConversationId() {
var conversationId = UUID.randomUUID().toString();
var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + conversationId),
new AssistantMessage("Message from assistant 2 - " + conversationId),
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
chatMemoryRepository.saveAll(conversationId, messages);
var results = chatMemoryRepository.findByConversationId(conversationId);
assertThat(results.size()).isEqualTo(messages.size());
assertThat(results).isEqualTo(messages);
}
@Test
void deleteMessagesByConversationId() {
var conversationId = UUID.randomUUID().toString();
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId),
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
chatMemoryRepository.saveAll(conversationId, messages);
chatMemoryRepository.deleteByConversationId(conversationId);
var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?",
Integer.class, conversationId);
assertThat(count).isZero();
}
@Test
void testMessageOrder() {
// Create a repository using the from method to detect the dialect
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource()))
.build();
var conversationId = UUID.randomUUID().toString();
// Create messages with very distinct content to make order obvious
var firstMessage = new UserMessage("1-First message");
var secondMessage = new AssistantMessage("2-Second message");
var thirdMessage = new UserMessage("3-Third message");
var fourthMessage = new SystemMessage("4-Fourth message");
// Save messages in the expected order
List<Message> orderedMessages = List.of(firstMessage, secondMessage, thirdMessage, fourthMessage);
repository.saveAll(conversationId, orderedMessages);
// Retrieve messages using the repository
List<Message> retrievedMessages = repository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(4);
// Get the actual order from the retrieved messages
List<String> retrievedContents = retrievedMessages.stream().map(Message::getText).collect(Collectors.toList());
// Messages should be in the original order (ASC)
assertThat(retrievedContents).containsExactly("1-First message", "2-Second message", "3-Third message",
"4-Fourth message");
}
/**
* Base configuration for all integration tests.
*/
@ImportAutoConfiguration({ DataSourceAutoConfiguration.class, JdbcTemplateAutoConfiguration.class })
static abstract class BaseTestConfiguration {
@Bean
ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate, DataSource dataSource) {
return JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(JdbcChatMemoryRepositoryDialect.from(dataSource))
.build();
}
}
}

View File

@@ -0,0 +1,41 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.memory.repository.jdbc;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.jdbc.Sql;
/**
* Integration tests for {@link JdbcChatMemoryRepository} with MySQL.
*
* @author Jonathan Leijendekker
* @author Thomas Vitale
* @author Mark Pollack
*/
@SpringBootTest(classes = JdbcChatMemoryRepositoryMysqlIT.TestConfiguration.class)
@TestPropertySource(properties = { "spring.datasource.url=jdbc:tc:mariadb:10.3.39:///" })
@Sql(scripts = "classpath:org/springframework/ai/chat/memory/repository/jdbc/schema-mariadb.sql")
class JdbcChatMemoryRepositoryMysqlIT extends AbstractJdbcChatMemoryRepositoryIT {
@SpringBootConfiguration
static class TestConfiguration extends BaseTestConfiguration {
}
}

View File

@@ -16,147 +16,38 @@
package org.springframework.ai.chat.memory.repository.jdbc;
import java.sql.Timestamp;
import java.util.List;
import java.util.UUID;
import javax.sql.DataSource;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.jdbc.Sql;
import org.springframework.transaction.support.TransactionTemplate;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Integration tests for {@link JdbcChatMemoryRepository}.
* Integration tests for {@link JdbcChatMemoryRepository} with PostgreSQL.
*
* @author Jonathan Leijendekker
* @author Thomas Vitale
* @author Mark Pollack
*/
@SpringBootTest(classes = JdbcChatMemoryRepositoryPostgresqlIT.TestConfiguration.class)
@TestPropertySource(properties = "spring.datasource.url=jdbc:tc:postgresql:17:///")
@Sql(scripts = "classpath:org/springframework/ai/chat/memory/repository/jdbc/schema-postgresql.sql")
class JdbcChatMemoryRepositoryPostgresqlIT {
@Autowired
private ChatMemoryRepository chatMemoryRepository;
@Autowired
private JdbcTemplate jdbcTemplate;
@Test
void correctChatMemoryRepositoryInstance() {
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
}
@ParameterizedTest
@CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" })
void saveMessagesSingleMessage(String content, MessageType messageType) {
var conversationId = UUID.randomUUID().toString();
var message = switch (messageType) {
case ASSISTANT -> new AssistantMessage(content + " - " + conversationId);
case USER -> new UserMessage(content + " - " + conversationId);
case SYSTEM -> new SystemMessage(content + " - " + conversationId);
default -> throw new IllegalArgumentException("Type not supported: " + messageType);
};
chatMemoryRepository.saveAll(conversationId, List.of(message));
var query = "SELECT conversation_id, content, type, \"timestamp\" FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?";
var result = jdbcTemplate.queryForMap(query, conversationId);
assertThat(result.size()).isEqualTo(4);
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
assertThat(result.get("content")).isEqualTo(message.getText());
assertThat(result.get("type")).isEqualTo(messageType.name());
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
}
@Test
void saveMessagesMultipleMessages() {
var conversationId = UUID.randomUUID().toString();
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId),
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
chatMemoryRepository.saveAll(conversationId, messages);
var query = "SELECT conversation_id, content, type, \"timestamp\" FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?";
var results = jdbcTemplate.queryForList(query, conversationId);
assertThat(results.size()).isEqualTo(messages.size());
for (var i = 0; i < messages.size(); i++) {
var message = messages.get(i);
var result = results.get(i);
assertThat(result.get("conversation_id")).isNotNull();
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
assertThat(result.get("content")).isEqualTo(message.getText());
assertThat(result.get("type")).isEqualTo(message.getMessageType().name());
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
}
var count = chatMemoryRepository.findByConversationId(conversationId).size();
assertThat(count).isEqualTo(messages.size());
chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello")));
count = chatMemoryRepository.findByConversationId(conversationId).size();
assertThat(count).isEqualTo(1);
}
@Test
void findMessagesByConversationId() {
var conversationId = UUID.randomUUID().toString();
var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + conversationId),
new AssistantMessage("Message from assistant 2 - " + conversationId),
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
chatMemoryRepository.saveAll(conversationId, messages);
var results = chatMemoryRepository.findByConversationId(conversationId);
assertThat(results.size()).isEqualTo(messages.size());
assertThat(results).isEqualTo(messages);
}
@Test
void deleteMessagesByConversationId() {
var conversationId = UUID.randomUUID().toString();
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId),
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
chatMemoryRepository.saveAll(conversationId, messages);
chatMemoryRepository.deleteByConversationId(conversationId);
var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?",
Integer.class, conversationId);
assertThat(count).isZero();
}
class JdbcChatMemoryRepositoryPostgresqlIT extends AbstractJdbcChatMemoryRepositoryIT {
@Test
void repositoryWithExplicitTransactionManager() {
@@ -187,16 +78,7 @@ class JdbcChatMemoryRepositoryPostgresqlIT {
}
@SpringBootConfiguration
@ImportAutoConfiguration({ DataSourceAutoConfiguration.class, JdbcTemplateAutoConfiguration.class })
static class TestConfiguration {
@Bean
ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate, DataSource dataSource) {
return JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(JdbcChatMemoryRepositoryDialect.from(dataSource))
.build();
}
static class TestConfiguration extends BaseTestConfiguration {
@Bean
ChatMemoryRepository chatMemoryRepositoryWithTxManager(JdbcTemplate jdbcTemplate, DataSource dataSource) {