Fix VectorStoreChatMemoryAdvisor streaming bug
- Override the adviseStream method in VectorStoreChatMemoryAdvisor to properly handle streaming responses
- Add tests to verify the fix works with both normal and problematic streaming scenarios

Fixes #3152

Signed-off-by: Mark Pollack <mark.pollack@broadcom.com>
This commit is contained in:
@@ -21,7 +21,10 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientMessageAggregator;
|
||||
import org.springframework.util.Assert;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
@@ -30,10 +33,12 @@ import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.document.Document;
|
||||
@@ -167,6 +172,20 @@ public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
|
||||
return chatClientResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain) {
|
||||
// Get the scheduler from BaseAdvisor
|
||||
Scheduler scheduler = this.getScheduler();
|
||||
// Process the request with the before method
|
||||
return Mono.just(chatClientRequest)
|
||||
.publishOn(scheduler)
|
||||
.map(request -> this.before(request, streamAdvisorChain))
|
||||
.flatMapMany(streamAdvisorChain::nextStream)
|
||||
.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
|
||||
response -> this.after(response, streamAdvisorChain)));
|
||||
}
|
||||
|
||||
private List<Document> toDocuments(List<Message> messages, String conversationId) {
|
||||
List<Document> docs = messages.stream()
|
||||
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
|
||||
|
||||
@@ -22,6 +22,7 @@ import java.util.List;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
|
||||
@@ -412,4 +413,74 @@ public abstract class AbstractChatMemoryAdvisorIT {
|
||||
assertThat(memoryMessages.get(6).getText()).isEqualTo("What is my name and where do I live?");
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests that the advisor correctly handles streaming responses and updates the
|
||||
* memory. This verifies that the adviseStream method in chat memory advisors is
|
||||
* working correctly.
|
||||
*/
|
||||
protected void testStreamingWithChatMemory() {
|
||||
// Arrange
|
||||
String conversationId = "streaming-conversation-" + System.currentTimeMillis();
|
||||
ChatMemory chatMemory = MessageWindowChatMemory.builder()
|
||||
.chatMemoryRepository(new InMemoryChatMemoryRepository())
|
||||
.build();
|
||||
|
||||
// Create advisor with the conversation ID
|
||||
var advisor = createAdvisor(chatMemory);
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
// Act - Send a message using streaming
|
||||
String initialQuestion = "My name is David and I live in Seattle.";
|
||||
|
||||
// Collect all streaming chunks
|
||||
List<String> streamingChunks = new ArrayList<>();
|
||||
Flux<String> responseStream = chatClient.prompt()
|
||||
.user(initialQuestion)
|
||||
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
|
||||
.stream()
|
||||
.content();
|
||||
|
||||
// Block and collect all streaming chunks
|
||||
responseStream.doOnNext(streamingChunks::add).blockLast();
|
||||
|
||||
// Join all chunks to get the complete response
|
||||
String completeResponse = String.join("", streamingChunks);
|
||||
|
||||
logger.info("Streaming response: {}", completeResponse);
|
||||
|
||||
// Verify memory contains the initial question and the response
|
||||
List<Message> memoryMessages = chatMemory.get(conversationId);
|
||||
assertThat(memoryMessages).hasSize(2); // 1 user message + 1 assistant response
|
||||
assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion);
|
||||
|
||||
// Send a follow-up question using streaming
|
||||
String followUpQuestion = "Where do I live?";
|
||||
|
||||
// Collect all streaming chunks for the follow-up
|
||||
List<String> followUpStreamingChunks = new ArrayList<>();
|
||||
Flux<String> followUpResponseStream = chatClient.prompt()
|
||||
.user(followUpQuestion)
|
||||
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
|
||||
.stream()
|
||||
.content();
|
||||
|
||||
// Block and collect all streaming chunks
|
||||
followUpResponseStream.doOnNext(followUpStreamingChunks::add).blockLast();
|
||||
|
||||
// Join all chunks to get the complete follow-up response
|
||||
String followUpCompleteResponse = String.join("", followUpStreamingChunks);
|
||||
|
||||
logger.info("Follow-up streaming response: {}", followUpCompleteResponse);
|
||||
|
||||
// Verify the follow-up response contains the location
|
||||
assertThat(followUpCompleteResponse).containsIgnoringCase("Seattle");
|
||||
|
||||
// Verify memory now contains all messages
|
||||
memoryMessages = chatMemory.get(conversationId);
|
||||
assertThat(memoryMessages).hasSize(4); // 2 user messages + 2 assistant responses
|
||||
assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion);
|
||||
assertThat(memoryMessages.get(2).getText()).isEqualTo(followUpQuestion);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -190,4 +190,9 @@ public class MessageChatMemoryAdvisorIT extends AbstractChatMemoryAdvisorIT {
|
||||
logger.info("Assistant response stored in memory: {}", assistantMessage.getText());
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldHandleStreamingWithChatMemory() {
|
||||
testStreamingWithChatMemory();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -135,4 +135,9 @@ public class PromptChatMemoryAdvisorIT extends AbstractChatMemoryAdvisorIT {
|
||||
testMultipleUserMessagesInPrompt();
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldHandleStreamingWithChatMemory() {
|
||||
testStreamingWithChatMemory();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ import org.postgresql.ds.PGSimpleDataSource;
|
||||
import org.testcontainers.containers.PostgreSQLContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.vectorstore.VectorStoreChatMemoryAdvisor;
|
||||
@@ -42,9 +43,11 @@ import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.fail;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.Mockito.mock;
|
||||
@@ -117,6 +120,78 @@ class PgVectorStoreWithChatMemoryAdvisorIT {
|
||||
""");
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a mock ChatModel that supports streaming responses for testing.
|
||||
* @return A mock ChatModel that returns a predefined streaming response
|
||||
*/
|
||||
private static @NotNull ChatModel chatModelWithStreamingSupport() {
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
|
||||
// Mock the regular call method
|
||||
ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
|
||||
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
|
||||
Why don't scientists trust atoms?
|
||||
Because they make up everything!
|
||||
"""))));
|
||||
given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse);
|
||||
|
||||
// Mock the streaming method
|
||||
ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class);
|
||||
Flux<ChatResponse> streamingResponse = Flux.just(
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))));
|
||||
given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse);
|
||||
|
||||
return chatModel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a mock ChatModel that simulates the problematic streaming behavior. This
|
||||
* mock includes a final empty message that triggers the bug in
|
||||
* VectorStoreChatMemoryAdvisor.
|
||||
* @return A mock ChatModel that returns a problematic streaming response
|
||||
*/
|
||||
private static @NotNull ChatModel chatModelWithProblematicStreamingBehavior() {
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
|
||||
// Mock the regular call method
|
||||
ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
|
||||
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
|
||||
Why don't scientists trust atoms?
|
||||
Because they make up everything!
|
||||
"""))));
|
||||
given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse);
|
||||
|
||||
// Mock the streaming method with a problematic final message (empty content)
|
||||
// This simulates the real-world condition that triggers the bug
|
||||
ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class);
|
||||
Flux<ChatResponse> streamingResponse = Flux.just(
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))),
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))),
|
||||
// This final empty message triggers the bug in
|
||||
// VectorStoreChatMemoryAdvisor
|
||||
new ChatResponse(List.of(new Generation(new AssistantMessage("")))));
|
||||
given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse);
|
||||
|
||||
return chatModel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar
|
||||
* messages from the (pg)vector store.
|
||||
@@ -182,6 +257,139 @@ class PgVectorStoreWithChatMemoryAdvisorIT {
|
||||
""");
|
||||
}
|
||||
|
||||
/**
|
||||
* Test that streaming chats with {@link VectorStoreChatMemoryAdvisor} get advised
|
||||
* with similar messages from the vector store and properly handle streaming
|
||||
* responses.
|
||||
*
|
||||
* This test verifies that the fix for the bug reported in
|
||||
* https://github.com/spring-projects/spring-ai/issues/3152 works correctly. The
|
||||
* VectorStoreChatMemoryAdvisor now properly handles streaming responses and saves the
|
||||
* assistant's messages to the vector store.
|
||||
*/
|
||||
@Test
|
||||
void advisedStreamingChatShouldHaveSimilarMessagesFromVectorStore() throws Exception {
|
||||
// Create a ChatModel with streaming support
|
||||
ChatModel chatModel = chatModelWithStreamingSupport();
|
||||
|
||||
// Create the embedding model
|
||||
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
|
||||
|
||||
// Create and initialize the vector store
|
||||
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
|
||||
String conversationId = UUID.randomUUID().toString();
|
||||
initStore(store, conversationId);
|
||||
|
||||
// Create a chat client with the VectorStoreChatMemoryAdvisor
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).build();
|
||||
|
||||
// Execute a streaming chat request
|
||||
Flux<String> responseStream = chatClient.prompt()
|
||||
.user("joke")
|
||||
.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
|
||||
.param(ChatMemory.CONVERSATION_ID, conversationId))
|
||||
.stream()
|
||||
.content();
|
||||
|
||||
// Collect all streaming chunks
|
||||
List<String> streamingChunks = responseStream.collectList().block();
|
||||
|
||||
// Verify the streaming response
|
||||
assertThat(streamingChunks).isNotNull();
|
||||
String completeResponse = String.join("", streamingChunks);
|
||||
assertThat(completeResponse).contains("scientists", "atoms", "everything");
|
||||
|
||||
// Verify the request was properly advised with vector store content
|
||||
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
|
||||
verify(chatModel).stream(promptCaptor.capture());
|
||||
Prompt capturedPrompt = promptCaptor.getValue();
|
||||
assertThat(capturedPrompt.getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(capturedPrompt.getInstructions().get(0).getText()).isEqualToIgnoringWhitespace("""
|
||||
|
||||
Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
|
||||
|
||||
---------------------
|
||||
LONG_TERM_MEMORY:
|
||||
Tell me a good joke
|
||||
Tell me a bad joke
|
||||
---------------------
|
||||
""");
|
||||
|
||||
// Verify that the assistant's response was properly added to the vector store
|
||||
// after
|
||||
// streaming completed
|
||||
// This verifies that the fix for the adviseStream implementation works correctly
|
||||
String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'";
|
||||
var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build();
|
||||
|
||||
List<Document> assistantDocuments = store.similaritySearch(searchRequest);
|
||||
|
||||
// With our fix, the assistant's response should be saved to the vector store
|
||||
assertThat(assistantDocuments).isNotEmpty();
|
||||
assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything");
|
||||
}
|
||||
|
||||
/**
|
||||
* Test that verifies the fix for the bug reported in
|
||||
* https://github.com/spring-projects/spring-ai/issues/3152. The
|
||||
* VectorStoreChatMemoryAdvisor now properly handles streaming responses with empty
|
||||
* messages by using ChatClientMessageAggregator to aggregate messages before calling
|
||||
* the after method.
|
||||
*/
|
||||
@Test
|
||||
void vectorStoreChatMemoryAdvisorShouldHandleEmptyMessagesInStream() throws Exception {
|
||||
// Create a ChatModel with problematic streaming behavior
|
||||
ChatModel chatModel = chatModelWithProblematicStreamingBehavior();
|
||||
|
||||
// Create the embedding model
|
||||
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
|
||||
|
||||
// Create and initialize the vector store
|
||||
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
|
||||
String conversationId = UUID.randomUUID().toString();
|
||||
initStore(store, conversationId);
|
||||
|
||||
// Create a chat client with the VectorStoreChatMemoryAdvisor
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).build();
|
||||
|
||||
// Execute a streaming chat request
|
||||
// This should now succeed with our fix
|
||||
Flux<String> responseStream = chatClient.prompt()
|
||||
.user("joke")
|
||||
.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
|
||||
.param(ChatMemory.CONVERSATION_ID, conversationId))
|
||||
.stream()
|
||||
.content();
|
||||
|
||||
// Collect all streaming chunks - this should no longer throw an exception
|
||||
List<String> streamingChunks = responseStream.collectList().block();
|
||||
|
||||
// Verify the streaming response
|
||||
assertThat(streamingChunks).isNotNull();
|
||||
String completeResponse = String.join("", streamingChunks);
|
||||
assertThat(completeResponse).contains("scientists", "atoms", "everything");
|
||||
|
||||
// Verify that the assistant's response was properly added to the vector store
|
||||
// This verifies that our fix works correctly
|
||||
String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'";
|
||||
var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build();
|
||||
|
||||
List<Document> assistantDocuments = store.similaritySearch(searchRequest);
|
||||
assertThat(assistantDocuments).isNotEmpty();
|
||||
assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything");
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method to get the root cause of an exception
|
||||
*/
|
||||
private Throwable getRootCause(Throwable throwable) {
|
||||
Throwable cause = throwable;
|
||||
while (cause.getCause() != null && cause.getCause() != cause) {
|
||||
cause = cause.getCause();
|
||||
}
|
||||
return cause;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() {
|
||||
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);
|
||||
|
||||
Reference in New Issue
Block a user