Address some review comments

This commit is contained in:
Christian Tzolov
2024-04-26 14:12:40 +02:00
parent 3857b83cd1
commit 8796896158
12 changed files with 64 additions and 47 deletions

View File

@@ -49,9 +49,9 @@ import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
@Testcontainers
@SpringBootTest(classes = TextChatHistoryChatAgentIT.Config.class)
@SpringBootTest(classes = ChatMemoryLongTermSystemPromptIT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class TextChatHistoryChatAgentIT extends BaseMemoryTest {
public class ChatMemoryLongTermSystemPromptIT extends BaseMemoryTest {
private static final String COLLECTION_NAME = "test_collection";
@@ -61,7 +61,7 @@ public class TextChatHistoryChatAgentIT extends BaseMemoryTest {
static QdrantContainer qdrantContainer = new QdrantContainer("qdrant/qdrant:v1.7.4");
@Autowired
public TextChatHistoryChatAgentIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
public ChatMemoryLongTermSystemPromptIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
StreamingChatAgent streamingChatAgent) {
super(relevancyEvaluator, chatAgent, streamingChatAgent);
}
@@ -103,7 +103,7 @@ public class TextChatHistoryChatAgentIT extends BaseMemoryTest {
return DefaultChatAgent.builder(chatClient)
.withRetrievers(List.of(new VectorStoreChatMemoryRetriever(vectorStore, 10)))
.withDocumentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
.withChatAgentListeners(List.of(new VectorStoreChatMemoryAgentListener(vectorStore)))
.build();

View File

@@ -40,12 +40,12 @@ import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
@SpringBootTest(classes = MessageChatHistoryChatAgentIT.Config.class)
@SpringBootTest(classes = ChatMemoryShortTermMessageListIT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class MessageChatHistoryChatAgentIT extends BaseMemoryTest {
public class ChatMemoryShortTermMessageListIT extends BaseMemoryTest {
@Autowired
public MessageChatHistoryChatAgentIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
public ChatMemoryShortTermMessageListIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
StreamingChatAgent streamingChatAgent) {
super(relevancyEvaluator, chatAgent, streamingChatAgent);
}
@@ -79,7 +79,7 @@ public class MessageChatHistoryChatAgentIT extends BaseMemoryTest {
return DefaultChatAgent.builder(chatClient)
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
.withDocumentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
.withAugmentors(List.of(new MessageChatMemoryAugmentor()))
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
.build();

View File

@@ -41,12 +41,12 @@ import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
@SpringBootTest(classes = OpenAiMemoryChatAgentIT.Config.class)
@SpringBootTest(classes = ChatMemoryShortTermSystemPromptIT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiMemoryChatAgentIT extends BaseMemoryTest {
public class ChatMemoryShortTermSystemPromptIT extends BaseMemoryTest {
@Autowired
public OpenAiMemoryChatAgentIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
public ChatMemoryShortTermSystemPromptIT(RelevancyEvaluator relevancyEvaluator, ChatAgent chatAgent,
StreamingChatAgent streamingChatAgent) {
super(relevancyEvaluator, chatAgent, streamingChatAgent);
}
@@ -80,7 +80,7 @@ public class OpenAiMemoryChatAgentIT extends BaseMemoryTest {
return DefaultChatAgent.builder(chatClient)
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
.withDocumentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
.build();

View File

@@ -72,9 +72,9 @@ import org.springframework.core.io.Resource;
import static org.assertj.core.api.Assertions.assertThat;
@Testcontainers
@SpringBootTest(classes = TextChatHistoryChatAgent3IT.Config.class)
@SpringBootTest(classes = LongShortTermChatMemoryWithRagIT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class TextChatHistoryChatAgent3IT {
public class LongShortTermChatMemoryWithRagIT {
protected final Logger logger = LoggerFactory.getLogger(getClass());
@@ -191,7 +191,7 @@ public class TextChatHistoryChatAgent3IT {
new VectorStoreChatMemoryRetriever(vectorStore, 10,
Map.of(TransformerContentType.LONG_TERM_MEMORY, ""))))
.withDocumentPostProcessors(List.of(
.withContentPostProcessors(List.of(
new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000,
Set.of(TransformerContentType.SHORT_TERM_MEMORY)),
new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000,

View File

@@ -16,16 +16,24 @@
package org.springframework.ai.openai.chat.agent;
import java.util.List;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.qdrant.QdrantContainer;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.agent.ChatAgent;
import org.springframework.ai.chat.agent.DefaultChatAgent;
import org.springframework.ai.chat.prompt.transformer.QuestionContextAugmentor;
import org.springframework.ai.chat.prompt.transformer.VectorStoreRetriever;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.transformer.PromptContext;
import org.springframework.ai.chat.prompt.transformer.QuestionContextAugmentor;
import org.springframework.ai.chat.prompt.transformer.VectorStoreRetriever;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.evaluation.EvaluationRequest;
import org.springframework.ai.evaluation.EvaluationResponse;
@@ -36,8 +44,8 @@ import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringBootConfiguration;
@@ -45,12 +53,18 @@ import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.Resource;
import java.util.List;
@Testcontainers
@SpringBootTest(classes = OpenAiDefaultChatAgentIT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiDefaultChatAgentIT {
private static final String COLLECTION_NAME = "test_collection";
private static final int QDRANT_GRPC_PORT = 6334;
@Container
static QdrantContainer qdrantContainer = new QdrantContainer("qdrant/qdrant:v1.7.4");
private final ChatClient chatClient;
private final VectorStore vectorStore;
@@ -110,8 +124,11 @@ public class OpenAiDefaultChatAgentIT {
}
@Bean
public VectorStore vectorStore(EmbeddingClient embeddingClient) {
return new SimpleVectorStore(embeddingClient);
public VectorStore qdrantVectorStore(EmbeddingClient embeddingClient) {
QdrantClient qdrantClient = new QdrantClient(QdrantGrpcClient
.newBuilder(qdrantContainer.getHost(), qdrantContainer.getMappedPort(QDRANT_GRPC_PORT), false)
.build());
return new QdrantVectorStore(qdrantClient, COLLECTION_NAME, embeddingClient);
}
@Bean

View File

@@ -113,7 +113,7 @@ public class DefaultChatAgent implements ChatAgent {
return this;
}
public DefaultChatAgentBuilder withDocumentPostProcessors(List<PromptTransformer> documentPostProcessors) {
public DefaultChatAgentBuilder withContentPostProcessors(List<PromptTransformer> documentPostProcessors) {
this.documentPostProcessors = documentPostProcessors;
return this;
}

View File

@@ -22,17 +22,17 @@ import org.springframework.ai.chat.messages.Message;
/**
* @author Christian Tzolov
*
*
*/
public interface ChatMemory {
default void add(String conversationId, Message messages) {
this.add(conversationId, List.of(messages));
default void add(String conversationId, Message message) {
this.add(conversationId, List.of(message));
}
void add(String conversationId, List<Message> messages);
List<Message> get(String conversationId);
List<Message> get(String conversationId, int lastN);
void clear(String conversationId);

View File

@@ -41,19 +41,26 @@ public class ChatMemoryRetriever implements PromptTransformer {
*/
private final Map<String, Object> additionalMetadata;
private final int maxHistorySize;
public ChatMemoryRetriever(ChatMemory chatHistory) {
this(chatHistory, Map.of());
}
public ChatMemoryRetriever(ChatMemory chatHistory, Map<String, Object> additionalMetadata) {
this(chatHistory, 1000, additionalMetadata);
}
public ChatMemoryRetriever(ChatMemory chatHistory, int maxHistorySize, Map<String, Object> additionalMetadata) {
this.chatHistory = chatHistory;
this.additionalMetadata = additionalMetadata;
this.maxHistorySize = maxHistorySize;
}
@Override
public PromptContext transform(PromptContext promptContext) {
List<Message> messageHistory = this.chatHistory.get(promptContext.getConversationId());
List<Message> messageHistory = this.chatHistory.get(promptContext.getConversationId(), maxHistorySize);
List<Content> historyContent = (messageHistory != null)
? messageHistory.stream().filter(m -> m.getMessageType() != MessageType.SYSTEM).map(m -> {

View File

@@ -37,8 +37,9 @@ public class InMemoryChatMemory implements ChatMemory {
}
@Override
public List<Message> get(String conversationId) {
return this.conversationHistory.get(conversationId);
public List<Message> get(String conversationId, int lastN) {
List<Message> all = this.conversationHistory.get(conversationId);
return all != null ? all.stream().skip(Math.max(0, all.size() - lastN)).toList() : List.of();
}
@Override

View File

@@ -83,8 +83,6 @@ public class LastMaxTokenSizeContentTransformer implements PromptTransformer {
List<Content> datum = this.doGetDatumToModify(promptContext);
// int totalSize = this.tokenCountEstimator.estimate(nonSystemChatMessages) -
// retrievalRequest.getTokenRunningTotal();
int totalSize = this.doEstimateTokenCount(datum);
if (totalSize <= this.maxTokenSize) {
@@ -106,7 +104,6 @@ public class LastMaxTokenSizeContentTransformer implements PromptTransformer {
while (index < datum.size() && totalSize > this.maxTokenSize) {
Content oldDatum = datum.get(index++);
// int oldMessageTokenSize = this.tokenCountEstimator.estimate(oldDatum);
int oldMessageTokenSize = this.doEstimateTokenCount(oldDatum);
totalSize = totalSize - oldMessageTokenSize;
}

View File

@@ -66,7 +66,7 @@ public class RelevancyEvaluator implements Evaluator {
.filter(node -> node != null && node.getContent() instanceof String)
.map(node -> (Content) node)
.map(Content::getContent)
.collect(Collectors.joining("\n"));
.collect(Collectors.joining(System.lineSeparator()));
return supportingData;
}

View File

@@ -45,7 +45,7 @@ import static org.mockito.Mockito.when;
* @author Christian Tzolov
*/
@ExtendWith(MockitoExtension.class)
public class ChatHistoryTests {
public class ChatMemoryTests {
@Mock
ChatClient chatClient;
@@ -57,13 +57,13 @@ public class ChatHistoryTests {
ArgumentCaptor<Prompt> promptCaptor;
@Test
public void chatAgentMessageHistory() {
public void chatMemoryMessageListAugmentor() {
ChatMemory chatHistory = new InMemoryChatMemory();
DefaultChatAgent chatAgent = DefaultChatAgent.builder(chatClient)
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
.withDocumentPostProcessors(
.withContentPostProcessors(
List.of(new LastMaxTokenSizeContentTransformer(new JTokkitTokenCountEstimator(), 10)))
.withAugmentors(List.of(new MessageChatMemoryAugmentor()))
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
@@ -73,13 +73,13 @@ public class ChatHistoryTests {
}
@Test
public void chatAgentTextHistory() {
public void chatMemorySystemPromptAugmentor() {
ChatMemory chatHistory = new InMemoryChatMemory();
DefaultChatAgent chatAgent = DefaultChatAgent.builder(chatClient)
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
.withDocumentPostProcessors(
.withContentPostProcessors(
List.of(new LastMaxTokenSizeContentTransformer(new JTokkitTokenCountEstimator(), 10)))
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
@@ -106,15 +106,10 @@ public class ChatHistoryTests {
assertThat(response1.getChatResponse().getResult().getOutput().getContent()).isEqualTo("assistant:1");
// List<Message> messages2 = promptCaptor.getValue().getInstructions();
// assertThat(messages2)
// .isEqualTo(List.of(new UserMessage("user:1"), new UserMessage("user:2"), new UserMessage("user:3"),
// new UserMessage("user:4"), new UserMessage("user:5")));
List<Content> contents = response1.getPromptContext().getContents();
assertThat(contents).hasSize(0);
List<Message> history = chatHistory.get("test-session-id");
List<Message> history = chatHistory.get("test-session-id", 1000);
assertThat(history).hasSize(6);
AgentResponse response2 = chatAgent.call(PromptContext.builder()
@@ -125,7 +120,7 @@ public class ChatHistoryTests {
assertThat(response2.getChatResponse().getResult().getOutput().getContent()).isEqualTo("assistant:2");
history = chatHistory.get("test-session-id");
history = chatHistory.get("test-session-id", 1000);
assertThat(history).hasSize(10);
contents = response2.getPromptContext().getContents();
@@ -139,7 +134,7 @@ public class ChatHistoryTests {
.withPrompt(new Prompt(List.of(new UserMessage("user:9")))).build());
assertThat(response3.getChatResponse().getResult().getOutput().getContent()).isEqualTo("assistant:3");
history = chatHistory.get("test-session-id");
history = chatHistory.get("test-session-id", 1000);
assertThat(history).hasSize(12);
contents = response3.getPromptContext().getContents();