Address some review comments
This commit is contained in:
@@ -49,9 +49,9 @@ import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
||||
@Testcontainers
|
||||
@SpringBootTest(classes = TextChatHistoryChatAgentIT.Config.class)
|
||||
@SpringBootTest(classes = ChatMemoryLongTermSystemPromptIT.Config.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
public class TextChatHistoryChatAgentIT extends BaseMemoryTest {
|
||||
public class ChatMemoryLongTermSystemPromptIT extends BaseMemoryTest {
|
||||
|
||||
private static final String COLLECTION_NAME = "test_collection";
|
||||
|
||||
@@ -61,7 +61,7 @@ public class TextChatHistoryChatAgentIT extends BaseMemoryTest {
|
||||
static QdrantContainer qdrantContainer = new QdrantContainer("qdrant/qdrant:v1.7.4");
|
||||
|
||||
@Autowired
|
||||
public TextChatHistoryChatAgentIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
|
||||
public ChatMemoryLongTermSystemPromptIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
|
||||
StreamingChatAgent streamingChatAgent) {
|
||||
super(relevancyEvaluator, chatAgent, streamingChatAgent);
|
||||
}
|
||||
@@ -103,7 +103,7 @@ public class TextChatHistoryChatAgentIT extends BaseMemoryTest {
|
||||
|
||||
return DefaultChatAgent.builder(chatClient)
|
||||
.withRetrievers(List.of(new VectorStoreChatMemoryRetriever(vectorStore, 10)))
|
||||
.withDocumentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
|
||||
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
|
||||
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
|
||||
.withChatAgentListeners(List.of(new VectorStoreChatMemoryAgentListener(vectorStore)))
|
||||
.build();
|
||||
@@ -40,12 +40,12 @@ import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
||||
@SpringBootTest(classes = MessageChatHistoryChatAgentIT.Config.class)
|
||||
@SpringBootTest(classes = ChatMemoryShortTermMessageListIT.Config.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
public class MessageChatHistoryChatAgentIT extends BaseMemoryTest {
|
||||
public class ChatMemoryShortTermMessageListIT extends BaseMemoryTest {
|
||||
|
||||
@Autowired
|
||||
public MessageChatHistoryChatAgentIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
|
||||
public ChatMemoryShortTermMessageListIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
|
||||
StreamingChatAgent streamingChatAgent) {
|
||||
super(relevancyEvaluator, chatAgent, streamingChatAgent);
|
||||
}
|
||||
@@ -79,7 +79,7 @@ public class MessageChatHistoryChatAgentIT extends BaseMemoryTest {
|
||||
|
||||
return DefaultChatAgent.builder(chatClient)
|
||||
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
|
||||
.withDocumentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
|
||||
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
|
||||
.withAugmentors(List.of(new MessageChatMemoryAugmentor()))
|
||||
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
|
||||
.build();
|
||||
@@ -41,12 +41,12 @@ import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
||||
@SpringBootTest(classes = OpenAiMemoryChatAgentIT.Config.class)
|
||||
@SpringBootTest(classes = ChatMemoryShortTermSystemPromptIT.Config.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
public class OpenAiMemoryChatAgentIT extends BaseMemoryTest {
|
||||
public class ChatMemoryShortTermSystemPromptIT extends BaseMemoryTest {
|
||||
|
||||
@Autowired
|
||||
public OpenAiMemoryChatAgentIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
|
||||
public ChatMemoryShortTermSystemPromptIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
|
||||
StreamingChatAgent streamingChatAgent) {
|
||||
super(relevancyEvaluator, chatAgent, streamingChatAgent);
|
||||
}
|
||||
@@ -80,7 +80,7 @@ public class OpenAiMemoryChatAgentIT extends BaseMemoryTest {
|
||||
|
||||
return DefaultChatAgent.builder(chatClient)
|
||||
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
|
||||
.withDocumentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
|
||||
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
|
||||
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
|
||||
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
|
||||
.build();
|
||||
@@ -72,9 +72,9 @@ import org.springframework.core.io.Resource;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Testcontainers
|
||||
@SpringBootTest(classes = TextChatHistoryChatAgent3IT.Config.class)
|
||||
@SpringBootTest(classes = LongShortTermChatMemoryWithRagIT.Config.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
public class TextChatHistoryChatAgent3IT {
|
||||
public class LongShortTermChatMemoryWithRagIT {
|
||||
|
||||
protected final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@@ -191,7 +191,7 @@ public class TextChatHistoryChatAgent3IT {
|
||||
new VectorStoreChatMemoryRetriever(vectorStore, 10,
|
||||
Map.of(TransformerContentType.LONG_TERM_MEMORY, ""))))
|
||||
|
||||
.withDocumentPostProcessors(List.of(
|
||||
.withContentPostProcessors(List.of(
|
||||
new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000,
|
||||
Set.of(TransformerContentType.SHORT_TERM_MEMORY)),
|
||||
new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000,
|
||||
@@ -16,16 +16,24 @@
|
||||
|
||||
package org.springframework.ai.openai.chat.agent;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import io.qdrant.client.QdrantClient;
|
||||
import io.qdrant.client.QdrantGrpcClient;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.qdrant.QdrantContainer;
|
||||
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.agent.ChatAgent;
|
||||
import org.springframework.ai.chat.agent.DefaultChatAgent;
|
||||
import org.springframework.ai.chat.prompt.transformer.QuestionContextAugmentor;
|
||||
import org.springframework.ai.chat.prompt.transformer.VectorStoreRetriever;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.transformer.PromptContext;
|
||||
import org.springframework.ai.chat.prompt.transformer.QuestionContextAugmentor;
|
||||
import org.springframework.ai.chat.prompt.transformer.VectorStoreRetriever;
|
||||
import org.springframework.ai.embedding.EmbeddingClient;
|
||||
import org.springframework.ai.evaluation.EvaluationRequest;
|
||||
import org.springframework.ai.evaluation.EvaluationResponse;
|
||||
@@ -36,8 +44,8 @@ import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.reader.JsonReader;
|
||||
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
@@ -45,12 +53,18 @@ import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Testcontainers
|
||||
@SpringBootTest(classes = OpenAiDefaultChatAgentIT.Config.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
public class OpenAiDefaultChatAgentIT {
|
||||
|
||||
private static final String COLLECTION_NAME = "test_collection";
|
||||
|
||||
private static final int QDRANT_GRPC_PORT = 6334;
|
||||
|
||||
@Container
|
||||
static QdrantContainer qdrantContainer = new QdrantContainer("qdrant/qdrant:v1.7.4");
|
||||
|
||||
private final ChatClient chatClient;
|
||||
|
||||
private final VectorStore vectorStore;
|
||||
@@ -110,8 +124,11 @@ public class OpenAiDefaultChatAgentIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public VectorStore vectorStore(EmbeddingClient embeddingClient) {
|
||||
return new SimpleVectorStore(embeddingClient);
|
||||
public VectorStore qdrantVectorStore(EmbeddingClient embeddingClient) {
|
||||
QdrantClient qdrantClient = new QdrantClient(QdrantGrpcClient
|
||||
.newBuilder(qdrantContainer.getHost(), qdrantContainer.getMappedPort(QDRANT_GRPC_PORT), false)
|
||||
.build());
|
||||
return new QdrantVectorStore(qdrantClient, COLLECTION_NAME, embeddingClient);
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -113,7 +113,7 @@ public class DefaultChatAgent implements ChatAgent {
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultChatAgentBuilder withDocumentPostProcessors(List<PromptTransformer> documentPostProcessors) {
|
||||
public DefaultChatAgentBuilder withContentPostProcessors(List<PromptTransformer> documentPostProcessors) {
|
||||
this.documentPostProcessors = documentPostProcessors;
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -22,17 +22,17 @@ import org.springframework.ai.chat.messages.Message;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*
|
||||
*
|
||||
*/
|
||||
public interface ChatMemory {
|
||||
|
||||
default void add(String conversationId, Message messages) {
|
||||
this.add(conversationId, List.of(messages));
|
||||
default void add(String conversationId, Message message) {
|
||||
this.add(conversationId, List.of(message));
|
||||
}
|
||||
|
||||
void add(String conversationId, List<Message> messages);
|
||||
|
||||
List<Message> get(String conversationId);
|
||||
List<Message> get(String conversationId, int lastN);
|
||||
|
||||
void clear(String conversationId);
|
||||
|
||||
|
||||
@@ -41,19 +41,26 @@ public class ChatMemoryRetriever implements PromptTransformer {
|
||||
*/
|
||||
private final Map<String, Object> additionalMetadata;
|
||||
|
||||
private final int maxHistorySize;
|
||||
|
||||
public ChatMemoryRetriever(ChatMemory chatHistory) {
|
||||
this(chatHistory, Map.of());
|
||||
}
|
||||
|
||||
public ChatMemoryRetriever(ChatMemory chatHistory, Map<String, Object> additionalMetadata) {
|
||||
this(chatHistory, 1000, additionalMetadata);
|
||||
}
|
||||
|
||||
public ChatMemoryRetriever(ChatMemory chatHistory, int maxHistorySize, Map<String, Object> additionalMetadata) {
|
||||
this.chatHistory = chatHistory;
|
||||
this.additionalMetadata = additionalMetadata;
|
||||
this.maxHistorySize = maxHistorySize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptContext transform(PromptContext promptContext) {
|
||||
|
||||
List<Message> messageHistory = this.chatHistory.get(promptContext.getConversationId());
|
||||
List<Message> messageHistory = this.chatHistory.get(promptContext.getConversationId(), maxHistorySize);
|
||||
|
||||
List<Content> historyContent = (messageHistory != null)
|
||||
? messageHistory.stream().filter(m -> m.getMessageType() != MessageType.SYSTEM).map(m -> {
|
||||
|
||||
@@ -37,8 +37,9 @@ public class InMemoryChatMemory implements ChatMemory {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> get(String conversationId) {
|
||||
return this.conversationHistory.get(conversationId);
|
||||
public List<Message> get(String conversationId, int lastN) {
|
||||
List<Message> all = this.conversationHistory.get(conversationId);
|
||||
return all != null ? all.stream().skip(Math.max(0, all.size() - lastN)).toList() : List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -83,8 +83,6 @@ public class LastMaxTokenSizeContentTransformer implements PromptTransformer {
|
||||
|
||||
List<Content> datum = this.doGetDatumToModify(promptContext);
|
||||
|
||||
// int totalSize = this.tokenCountEstimator.estimate(nonSystemChatMessages) -
|
||||
// retrievalRequest.getTokenRunningTotal();
|
||||
int totalSize = this.doEstimateTokenCount(datum);
|
||||
|
||||
if (totalSize <= this.maxTokenSize) {
|
||||
@@ -106,7 +104,6 @@ public class LastMaxTokenSizeContentTransformer implements PromptTransformer {
|
||||
|
||||
while (index < datum.size() && totalSize > this.maxTokenSize) {
|
||||
Content oldDatum = datum.get(index++);
|
||||
// int oldMessageTokenSize = this.tokenCountEstimator.estimate(oldDatum);
|
||||
int oldMessageTokenSize = this.doEstimateTokenCount(oldDatum);
|
||||
totalSize = totalSize - oldMessageTokenSize;
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ public class RelevancyEvaluator implements Evaluator {
|
||||
.filter(node -> node != null && node.getContent() instanceof String)
|
||||
.map(node -> (Content) node)
|
||||
.map(Content::getContent)
|
||||
.collect(Collectors.joining("\n"));
|
||||
.collect(Collectors.joining(System.lineSeparator()));
|
||||
return supportingData;
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ import static org.mockito.Mockito.when;
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class ChatHistoryTests {
|
||||
public class ChatMemoryTests {
|
||||
|
||||
@Mock
|
||||
ChatClient chatClient;
|
||||
@@ -57,13 +57,13 @@ public class ChatHistoryTests {
|
||||
ArgumentCaptor<Prompt> promptCaptor;
|
||||
|
||||
@Test
|
||||
public void chatAgentMessageHistory() {
|
||||
public void chatMemoryMessageListAugmentor() {
|
||||
|
||||
ChatMemory chatHistory = new InMemoryChatMemory();
|
||||
|
||||
DefaultChatAgent chatAgent = DefaultChatAgent.builder(chatClient)
|
||||
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
|
||||
.withDocumentPostProcessors(
|
||||
.withContentPostProcessors(
|
||||
List.of(new LastMaxTokenSizeContentTransformer(new JTokkitTokenCountEstimator(), 10)))
|
||||
.withAugmentors(List.of(new MessageChatMemoryAugmentor()))
|
||||
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
|
||||
@@ -73,13 +73,13 @@ public class ChatHistoryTests {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void chatAgentTextHistory() {
|
||||
public void chatMemorySystemPromptAugmentor() {
|
||||
|
||||
ChatMemory chatHistory = new InMemoryChatMemory();
|
||||
|
||||
DefaultChatAgent chatAgent = DefaultChatAgent.builder(chatClient)
|
||||
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
|
||||
.withDocumentPostProcessors(
|
||||
.withContentPostProcessors(
|
||||
List.of(new LastMaxTokenSizeContentTransformer(new JTokkitTokenCountEstimator(), 10)))
|
||||
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
|
||||
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
|
||||
@@ -106,15 +106,10 @@ public class ChatHistoryTests {
|
||||
|
||||
assertThat(response1.getChatResponse().getResult().getOutput().getContent()).isEqualTo("assistant:1");
|
||||
|
||||
// List<Message> messages2 = promptCaptor.getValue().getInstructions();
|
||||
// assertThat(messages2)
|
||||
// .isEqualTo(List.of(new UserMessage("user:1"), new UserMessage("user:2"), new UserMessage("user:3"),
|
||||
// new UserMessage("user:4"), new UserMessage("user:5")));
|
||||
|
||||
List<Content> contents = response1.getPromptContext().getContents();
|
||||
assertThat(contents).hasSize(0);
|
||||
|
||||
List<Message> history = chatHistory.get("test-session-id");
|
||||
List<Message> history = chatHistory.get("test-session-id", 1000);
|
||||
assertThat(history).hasSize(6);
|
||||
|
||||
AgentResponse response2 = chatAgent.call(PromptContext.builder()
|
||||
@@ -125,7 +120,7 @@ public class ChatHistoryTests {
|
||||
|
||||
assertThat(response2.getChatResponse().getResult().getOutput().getContent()).isEqualTo("assistant:2");
|
||||
|
||||
history = chatHistory.get("test-session-id");
|
||||
history = chatHistory.get("test-session-id", 1000);
|
||||
assertThat(history).hasSize(10);
|
||||
|
||||
contents = response2.getPromptContext().getContents();
|
||||
@@ -139,7 +134,7 @@ public class ChatHistoryTests {
|
||||
.withPrompt(new Prompt(List.of(new UserMessage("user:9")))).build());
|
||||
assertThat(response3.getChatResponse().getResult().getOutput().getContent()).isEqualTo("assistant:3");
|
||||
|
||||
history = chatHistory.get("test-session-id");
|
||||
history = chatHistory.get("test-session-id", 1000);
|
||||
assertThat(history).hasSize(12);
|
||||
|
||||
contents = response3.getPromptContext().getContents();
|
||||
Reference in New Issue
Block a user