diff --git a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java index 16bbe97f3..146b14e44 100644 --- a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java +++ b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java @@ -21,7 +21,10 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.springframework.ai.chat.client.ChatClientMessageAggregator; import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import org.springframework.ai.chat.client.ChatClientRequest; @@ -30,10 +33,12 @@ import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.AdvisorChain; import org.springframework.ai.chat.client.advisor.api.BaseAdvisor; import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; @@ -167,6 +172,20 @@ public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor { return chatClientResponse; } + @Override + public Flux<ChatClientResponse> 
adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + // Scheduler on which the before() step is published + Scheduler scheduler = this.getScheduler(); + // Apply before(), stream the rest of the chain, then run after() on the aggregated response + return Mono.just(chatClientRequest) + .publishOn(scheduler) + .map(request -> this.before(request, streamAdvisorChain)) + .flatMapMany(streamAdvisorChain::nextStream) + .transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux, + response -> this.after(response, streamAdvisorChain))); + } + private List<Document> toDocuments(List<Message> messages, String conversationId) { List<Document> docs = messages.stream() .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/AbstractChatMemoryAdvisorIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/AbstractChatMemoryAdvisorIT.java index 238ceb171..b18fe8200 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/AbstractChatMemoryAdvisorIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/AbstractChatMemoryAdvisorIT.java @@ -22,6 +22,7 @@ import java.util.List; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor; @@ -412,4 +413,74 @@ public abstract class AbstractChatMemoryAdvisorIT { assertThat(memoryMessages.get(6).getText()).isEqualTo("What is my name and where do I live?"); } + /** + * Tests that the advisor correctly handles streaming responses and updates the + * memory. This verifies that the adviseStream method in chat memory advisors is + * working correctly. 
+ */ + protected void testStreamingWithChatMemory() { + // Arrange + String conversationId = "streaming-conversation-" + System.currentTimeMillis(); + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + // Create the advisor under test on top of the chat memory + var advisor = createAdvisor(chatMemory); + + ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build(); + + // Act - Send a message using streaming + String initialQuestion = "My name is David and I live in Seattle."; + + // Collect all streaming chunks + List<String> streamingChunks = new ArrayList<>(); + Flux<String> responseStream = chatClient.prompt() + .user(initialQuestion) + .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId)) + .stream() + .content(); + + // Block and collect all streaming chunks + responseStream.doOnNext(streamingChunks::add).blockLast(); + + // Join all chunks to get the complete response + String completeResponse = String.join("", streamingChunks); + + logger.info("Streaming response: {}", completeResponse); + + // Verify memory contains the initial question and the response + List<Message> memoryMessages = chatMemory.get(conversationId); + assertThat(memoryMessages).hasSize(2); // 1 user message + 1 assistant response + assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion); + + // Send a follow-up question using streaming + String followUpQuestion = "Where do I live?"; + + // Collect all streaming chunks for the follow-up + List<String> followUpStreamingChunks = new ArrayList<>(); + Flux<String> followUpResponseStream = chatClient.prompt() + .user(followUpQuestion) + .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId)) + .stream() + .content(); + + // Block and collect all streaming chunks + followUpResponseStream.doOnNext(followUpStreamingChunks::add).blockLast(); + + // Join all chunks to get the complete follow-up response + String followUpCompleteResponse = String.join("", 
followUpStreamingChunks); + + logger.info("Follow-up streaming response: {}", followUpCompleteResponse); + + // Verify the follow-up response contains the location + assertThat(followUpCompleteResponse).containsIgnoringCase("Seattle"); + + // Verify memory now contains all messages + memoryMessages = chatMemory.get(conversationId); + assertThat(memoryMessages).hasSize(4); // 2 user messages + 2 assistant responses + assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion); + assertThat(memoryMessages.get(2).getText()).isEqualTo(followUpQuestion); + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java index 45f63ac71..392ae0377 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java @@ -190,4 +190,9 @@ public class MessageChatMemoryAdvisorIT extends AbstractChatMemoryAdvisorIT { logger.info("Assistant response stored in memory: {}", assistantMessage.getText()); } + @Test + void shouldHandleStreamingWithChatMemory() { + testStreamingWithChatMemory(); + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/PromptChatMemoryAdvisorIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/PromptChatMemoryAdvisorIT.java index 6f0da2c87..cfa871235 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/PromptChatMemoryAdvisorIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/PromptChatMemoryAdvisorIT.java @@ -135,4 +135,9 @@ public class PromptChatMemoryAdvisorIT extends 
AbstractChatMemoryAdvisorIT { testMultipleUserMessagesInPrompt(); } + @Test + void shouldHandleStreamingWithChatMemory() { + testStreamingWithChatMemory(); + } + } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java index 146aed686..40c315c42 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java @@ -30,6 +30,7 @@ import org.postgresql.ds.PGSimpleDataSource; import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.vectorstore.VectorStoreChatMemoryAdvisor; @@ -42,9 +43,11 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.jdbc.core.JdbcTemplate; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -117,6 +120,78 @@ class PgVectorStoreWithChatMemoryAdvisorIT { """); } + /** + * Create a mock ChatModel that supports streaming responses for testing. 
+ * @return A mock ChatModel that returns a predefined streaming response + */ + private static @NotNull ChatModel chatModelWithStreamingSupport() { + ChatModel chatModel = mock(ChatModel.class); + + // Mock the regular call method + ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage(""" + Why don't scientists trust atoms? + Because they make up everything! + """)))); + given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse); + + // Mock the streaming method + ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class); + Flux<ChatResponse> streamingResponse = Flux.just( + new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))), + new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!"))))); + given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse); + + return chatModel; + } + + /** + * Create a mock ChatModel that simulates the problematic streaming behavior. This + * mock ends its stream with an empty message, the condition reported in + * spring-ai issue #3152. 
+ * @return A mock ChatModel that returns a problematic streaming response + */ + private static @NotNull ChatModel chatModelWithProblematicStreamingBehavior() { + ChatModel chatModel = mock(ChatModel.class); + + // Mock the regular call method + ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage(""" + Why don't scientists trust atoms? + Because they make up everything! + """)))); + given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse); + + // Mock the streaming method with a problematic final message (empty content) + // This simulates the real-world condition reported in issue #3152 + ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class); + Flux<ChatResponse> streamingResponse = Flux.just( + new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))), + new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))), + // This final empty message reproduces the condition reported in + // issue #3152 + new ChatResponse(List.of(new Generation(new AssistantMessage(""))))); + given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse); + + return chatModel; + } + /** * Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar * messages from the 
(gp)vector store. @@ -182,6 +257,139 @@ class PgVectorStoreWithChatMemoryAdvisorIT { """); } + /** + * Test that streaming chats with {@link VectorStoreChatMemoryAdvisor} get advised + * with similar messages from the vector store and properly handle streaming + * responses. + * + * This test verifies that the fix for the bug reported in + * https://github.com/spring-projects/spring-ai/issues/3152 works correctly. The + * VectorStoreChatMemoryAdvisor now properly handles streaming responses and saves the + * assistant's messages to the vector store. + */ + @Test + void advisedStreamingChatShouldHaveSimilarMessagesFromVectorStore() throws Exception { + // Create a ChatModel with streaming support + ChatModel chatModel = chatModelWithStreamingSupport(); + + // Create the embedding model + EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed(); + + // Create and initialize the vector store + PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel); + String conversationId = UUID.randomUUID().toString(); + initStore(store, conversationId); + + // Create a chat client with the VectorStoreChatMemoryAdvisor + ChatClient chatClient = ChatClient.builder(chatModel).build(); + + // Execute a streaming chat request + Flux<String> responseStream = chatClient.prompt() + .user("joke") + .advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build()) + .param(ChatMemory.CONVERSATION_ID, conversationId)) + .stream() + .content(); + + // Collect all streaming chunks + List<String> streamingChunks = responseStream.collectList().block(); + + // Verify the streaming response + assertThat(streamingChunks).isNotNull(); + String completeResponse = String.join("", streamingChunks); + assertThat(completeResponse).contains("scientists", "atoms", "everything"); + + // Verify the request was properly advised with vector store content + ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class); + 
verify(chatModel).stream(promptCaptor.capture()); + Prompt capturedPrompt = promptCaptor.getValue(); + assertThat(capturedPrompt.getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(capturedPrompt.getInstructions().get(0).getText()).isEqualToIgnoringWhitespace(""" + + Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers. + + --------------------- + LONG_TERM_MEMORY: + Tell me a good joke + Tell me a bad joke + --------------------- + """); + + // Verify that the assistant's response was properly added to the + // vector store after streaming completed. + // This verifies that the fix for the adviseStream implementation + // works correctly. + String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'"; + var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build(); + + List<Document> assistantDocuments = store.similaritySearch(searchRequest); + + // With our fix, the assistant's response should be saved to the vector store + assertThat(assistantDocuments).isNotEmpty(); + assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything"); + } + + /** + * Test that verifies the fix for the bug reported in + * https://github.com/spring-projects/spring-ai/issues/3152. The + * VectorStoreChatMemoryAdvisor now properly handles streaming responses with empty + * messages by using ChatClientMessageAggregator to aggregate messages before calling + * the after method. 
+ */ + @Test + void vectorStoreChatMemoryAdvisorShouldHandleEmptyMessagesInStream() throws Exception { + // Create a ChatModel with problematic streaming behavior + ChatModel chatModel = chatModelWithProblematicStreamingBehavior(); + + // Create the embedding model + EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed(); + + // Create and initialize the vector store + PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel); + String conversationId = UUID.randomUUID().toString(); + initStore(store, conversationId); + + // Create a chat client with the VectorStoreChatMemoryAdvisor + ChatClient chatClient = ChatClient.builder(chatModel).build(); + + // Execute a streaming chat request + // This should now succeed with our fix + Flux<String> responseStream = chatClient.prompt() + .user("joke") + .advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build()) + .param(ChatMemory.CONVERSATION_ID, conversationId)) + .stream() + .content(); + + // Collect all streaming chunks - this should no longer throw an exception + List<String> streamingChunks = responseStream.collectList().block(); + + // Verify the streaming response + assertThat(streamingChunks).isNotNull(); + String completeResponse = String.join("", streamingChunks); + assertThat(completeResponse).contains("scientists", "atoms", "everything"); + + // Verify that the assistant's response was properly added to the vector store + // This verifies that our fix works correctly + String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'"; + var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build(); + + List<Document> assistantDocuments = store.similaritySearch(searchRequest); + assertThat(assistantDocuments).isNotEmpty(); + assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything"); + } + + /** + * Walks the cause chain to its root; not called by the tests added in this patch + */ + private Throwable 
getRootCause(Throwable throwable) { + Throwable cause = throwable; + while (cause.getCause() != null && cause.getCause() != cause) { + cause = cause.getCause(); + } + return cause; + } + @SuppressWarnings("unchecked") private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() { EmbeddingModel embeddingModel = mock(EmbeddingModel.class);