diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java index 14d8c2d13..f7c4b96cd 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java @@ -53,6 +53,7 @@ import org.springframework.util.Assert; * @author Thomas Vitale * @author Linar Abzaltdinov * @author Mark Pollack + * @author Yanming Zhou * @since 1.0.0 */ public final class JdbcChatMemoryRepository implements ChatMemoryRepository { @@ -65,14 +66,14 @@ public final class JdbcChatMemoryRepository implements ChatMemoryRepository { private static final Logger logger = LoggerFactory.getLogger(JdbcChatMemoryRepository.class); - private JdbcChatMemoryRepository(DataSource dataSource, JdbcChatMemoryRepositoryDialect dialect, + private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect, PlatformTransactionManager txManager) { - Assert.notNull(dataSource, "dataSource cannot be null"); + Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null"); Assert.notNull(dialect, "dialect cannot be null"); - this.jdbcTemplate = new JdbcTemplate(dataSource); + this.jdbcTemplate = jdbcTemplate; this.dialect = dialect; this.transactionTemplate = new TransactionTemplate( - txManager != null ? txManager : new DataSourceTransactionManager(dataSource)); + txManager != null ? txManager : new DataSourceTransactionManager(jdbcTemplate.getDataSource())); } @Override @@ -192,7 +193,18 @@ public final class JdbcChatMemoryRepository implements ChatMemoryRepository { public JdbcChatMemoryRepository build() { DataSource effectiveDataSource = resolveDataSource(); JdbcChatMemoryRepositoryDialect effectiveDialect = resolveDialect(effectiveDataSource); - return new JdbcChatMemoryRepository(effectiveDataSource, effectiveDialect, this.platformTransactionManager); + return new JdbcChatMemoryRepository(resolveJdbcTemplate(), effectiveDialect, + this.platformTransactionManager); + } + + private JdbcTemplate resolveJdbcTemplate() { + if (this.jdbcTemplate != null) { + return this.jdbcTemplate; + } + if (this.dataSource != null) { + return new JdbcTemplate(this.dataSource); + } + throw new IllegalArgumentException("DataSource must be set (either via dataSource() or jdbcTemplate())"); } private DataSource resolveDataSource() { diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java index a6f895154..87dbfa5f0 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java @@ -24,6 +24,7 @@ import javax.sql.DataSource; import org.junit.jupiter.api.Test; +import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.transaction.PlatformTransactionManager; import static org.assertj.core.api.Assertions.assertThat; @@ -35,6 +36,7 @@ import static org.mockito.Mockito.when; * Tests for {@link JdbcChatMemoryRepository.Builder}. * * @author Mark Pollack + * @author Yanming Zhou */ public class JdbcChatMemoryRepositoryBuilderTests { @@ -224,4 +226,14 @@ public class JdbcChatMemoryRepositoryBuilderTests { // for this) } + @Test + void repositoryShouldUseProvidedJdbcTemplate() throws SQLException { + DataSource dataSource = mock(DataSource.class); + JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource); + + JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().jdbcTemplate(jdbcTemplate).build(); + + assertThat(repository).extracting("jdbcTemplate").isSameAs(jdbcTemplate); + } + }