Fix VectorStoreChatMemoryAdvisor streaming bug

- Override adviseStream method in VectorStoreChatMemoryAdvisor to properly handle streaming responses
- Add tests to verify the fix works with both normal and problematic streaming scenarios

Fixes #3152

Signed-off-by: Mark Pollack <mark.pollack@broadcom.com>
This commit is contained in:
Mark Pollack
2025-05-15 17:50:27 -04:00
parent 008a760fa7
commit 867cc302cb
5 changed files with 308 additions and 0 deletions

View File

@@ -21,7 +21,10 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.client.ChatClientMessageAggregator;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import org.springframework.ai.chat.client.ChatClientRequest;
@@ -30,10 +33,12 @@ import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
@@ -167,6 +172,20 @@ public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
return chatClientResponse;
}
@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
// Get the scheduler from BaseAdvisor
Scheduler scheduler = this.getScheduler();
// Process the request with the before method
return Mono.just(chatClientRequest)
.publishOn(scheduler)
.map(request -> this.before(request, streamAdvisorChain))
.flatMapMany(streamAdvisorChain::nextStream)
.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
response -> this.after(response, streamAdvisorChain)));
}
private List<Document> toDocuments(List<Message> messages, String conversationId) {
List<Document> docs = messages.stream()
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)

View File

@@ -22,6 +22,7 @@ import java.util.List;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
@@ -412,4 +413,74 @@ public abstract class AbstractChatMemoryAdvisorIT {
assertThat(memoryMessages.get(6).getText()).isEqualTo("What is my name and where do I live?");
}
/**
* Tests that the advisor correctly handles streaming responses and updates the
* memory. This verifies that the adviseStream method in chat memory advisors is
* working correctly.
*/
protected void testStreamingWithChatMemory() {
// Arrange
String conversationId = "streaming-conversation-" + System.currentTimeMillis();
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();
// Create advisor with the conversation ID
var advisor = createAdvisor(chatMemory);
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
// Act - Send a message using streaming
String initialQuestion = "My name is David and I live in Seattle.";
// Collect all streaming chunks
List<String> streamingChunks = new ArrayList<>();
Flux<String> responseStream = chatClient.prompt()
.user(initialQuestion)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.stream()
.content();
// Block and collect all streaming chunks
responseStream.doOnNext(streamingChunks::add).blockLast();
// Join all chunks to get the complete response
String completeResponse = String.join("", streamingChunks);
logger.info("Streaming response: {}", completeResponse);
// Verify memory contains the initial question and the response
List<Message> memoryMessages = chatMemory.get(conversationId);
assertThat(memoryMessages).hasSize(2); // 1 user message + 1 assistant response
assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion);
// Send a follow-up question using streaming
String followUpQuestion = "Where do I live?";
// Collect all streaming chunks for the follow-up
List<String> followUpStreamingChunks = new ArrayList<>();
Flux<String> followUpResponseStream = chatClient.prompt()
.user(followUpQuestion)
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
.stream()
.content();
// Block and collect all streaming chunks
followUpResponseStream.doOnNext(followUpStreamingChunks::add).blockLast();
// Join all chunks to get the complete follow-up response
String followUpCompleteResponse = String.join("", followUpStreamingChunks);
logger.info("Follow-up streaming response: {}", followUpCompleteResponse);
// Verify the follow-up response contains the location
assertThat(followUpCompleteResponse).containsIgnoringCase("Seattle");
// Verify memory now contains all messages
memoryMessages = chatMemory.get(conversationId);
assertThat(memoryMessages).hasSize(4); // 2 user messages + 2 assistant responses
assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion);
assertThat(memoryMessages.get(2).getText()).isEqualTo(followUpQuestion);
}
}

View File

@@ -190,4 +190,9 @@ public class MessageChatMemoryAdvisorIT extends AbstractChatMemoryAdvisorIT {
logger.info("Assistant response stored in memory: {}", assistantMessage.getText());
}
@Test
void shouldHandleStreamingWithChatMemory() {
testStreamingWithChatMemory();
}
}

View File

@@ -135,4 +135,9 @@ public class PromptChatMemoryAdvisorIT extends AbstractChatMemoryAdvisorIT {
testMultipleUserMessagesInPrompt();
}
@Test
void shouldHandleStreamingWithChatMemory() {
testStreamingWithChatMemory();
}
}

View File

@@ -30,6 +30,7 @@ import org.postgresql.ds.PGSimpleDataSource;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.vectorstore.VectorStoreChatMemoryAdvisor;
@@ -42,9 +43,11 @@ import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.jdbc.core.JdbcTemplate;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
@@ -117,6 +120,78 @@ class PgVectorStoreWithChatMemoryAdvisorIT {
""");
}
/**
* Create a mock ChatModel that supports streaming responses for testing.
* @return A mock ChatModel that returns a predefined streaming response
*/
private static @NotNull ChatModel chatModelWithStreamingSupport() {
ChatModel chatModel = mock(ChatModel.class);
// Mock the regular call method
ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
Why don't scientists trust atoms?
Because they make up everything!
"""))));
given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse);
// Mock the streaming method
ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class);
Flux<ChatResponse> streamingResponse = Flux.just(
new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))),
new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))));
given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse);
return chatModel;
}
/**
* Create a mock ChatModel that simulates the problematic streaming behavior. This
* mock includes a final empty message that triggers the bug in
* VectorStoreChatMemoryAdvisor.
* @return A mock ChatModel that returns a problematic streaming response
*/
private static @NotNull ChatModel chatModelWithProblematicStreamingBehavior() {
ChatModel chatModel = mock(ChatModel.class);
// Mock the regular call method
ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
Why don't scientists trust atoms?
Because they make up everything!
"""))));
given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse);
// Mock the streaming method with a problematic final message (empty content)
// This simulates the real-world condition that triggers the bug
ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class);
Flux<ChatResponse> streamingResponse = Flux.just(
new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))),
new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))),
new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))),
// This final empty message triggers the bug in
// VectorStoreChatMemoryAdvisor
new ChatResponse(List.of(new Generation(new AssistantMessage("")))));
given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse);
return chatModel;
}
/**
* Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar
* messages from the (gp)vector store.
@@ -182,6 +257,139 @@ class PgVectorStoreWithChatMemoryAdvisorIT {
""");
}
/**
* Test that streaming chats with {@link VectorStoreChatMemoryAdvisor} get advised
* with similar messages from the vector store and properly handle streaming
* responses.
*
* This test verifies that the fix for the bug reported in
* https://github.com/spring-projects/spring-ai/issues/3152 works correctly. The
* VectorStoreChatMemoryAdvisor now properly handles streaming responses and saves the
* assistant's messages to the vector store.
*/
@Test
void advisedStreamingChatShouldHaveSimilarMessagesFromVectorStore() throws Exception {
// Create a ChatModel with streaming support
ChatModel chatModel = chatModelWithStreamingSupport();
// Create the embedding model
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
// Create and initialize the vector store
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
String conversationId = UUID.randomUUID().toString();
initStore(store, conversationId);
// Create a chat client with the VectorStoreChatMemoryAdvisor
ChatClient chatClient = ChatClient.builder(chatModel).build();
// Execute a streaming chat request
Flux<String> responseStream = chatClient.prompt()
.user("joke")
.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
.param(ChatMemory.CONVERSATION_ID, conversationId))
.stream()
.content();
// Collect all streaming chunks
List<String> streamingChunks = responseStream.collectList().block();
// Verify the streaming response
assertThat(streamingChunks).isNotNull();
String completeResponse = String.join("", streamingChunks);
assertThat(completeResponse).contains("scientists", "atoms", "everything");
// Verify the request was properly advised with vector store content
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
verify(chatModel).stream(promptCaptor.capture());
Prompt capturedPrompt = promptCaptor.getValue();
assertThat(capturedPrompt.getInstructions().get(0)).isInstanceOf(SystemMessage.class);
assertThat(capturedPrompt.getInstructions().get(0).getText()).isEqualToIgnoringWhitespace("""
Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
---------------------
LONG_TERM_MEMORY:
Tell me a good joke
Tell me a bad joke
---------------------
""");
// Verify that the assistant's response was properly added to the vector store
// after
// streaming completed
// This verifies that the fix for the adviseStream implementation works correctly
String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'";
var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build();
List<Document> assistantDocuments = store.similaritySearch(searchRequest);
// With our fix, the assistant's response should be saved to the vector store
assertThat(assistantDocuments).isNotEmpty();
assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything");
}
/**
* Test that verifies the fix for the bug reported in
* https://github.com/spring-projects/spring-ai/issues/3152. The
* VectorStoreChatMemoryAdvisor now properly handles streaming responses with empty
* messages by using ChatClientMessageAggregator to aggregate messages before calling
* the after method.
*/
@Test
void vectorStoreChatMemoryAdvisorShouldHandleEmptyMessagesInStream() throws Exception {
// Create a ChatModel with problematic streaming behavior
ChatModel chatModel = chatModelWithProblematicStreamingBehavior();
// Create the embedding model
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
// Create and initialize the vector store
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
String conversationId = UUID.randomUUID().toString();
initStore(store, conversationId);
// Create a chat client with the VectorStoreChatMemoryAdvisor
ChatClient chatClient = ChatClient.builder(chatModel).build();
// Execute a streaming chat request
// This should now succeed with our fix
Flux<String> responseStream = chatClient.prompt()
.user("joke")
.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
.param(ChatMemory.CONVERSATION_ID, conversationId))
.stream()
.content();
// Collect all streaming chunks - this should no longer throw an exception
List<String> streamingChunks = responseStream.collectList().block();
// Verify the streaming response
assertThat(streamingChunks).isNotNull();
String completeResponse = String.join("", streamingChunks);
assertThat(completeResponse).contains("scientists", "atoms", "everything");
// Verify that the assistant's response was properly added to the vector store
// This verifies that our fix works correctly
String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'";
var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build();
List<Document> assistantDocuments = store.similaritySearch(searchRequest);
assertThat(assistantDocuments).isNotEmpty();
assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything");
}
/**
* Helper method to get the root cause of an exception
*/
private Throwable getRootCause(Throwable throwable) {
Throwable cause = throwable;
while (cause.getCause() != null && cause.getCause() != cause) {
cause = cause.getCause();
}
return cause;
}
@SuppressWarnings("unchecked")
private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() {
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);